usermatrix.h

Go to the documentation of this file.
00001 // Copyright (C) 2006 Ivo Nowak and Stefan Vigerske
00002 // All Rights Reserved.
00003 // This code is published under the Common Public License.
00004 //
00005 // Author: Ivo Nowak, Stefan Vigerske
00006 
00007 // usermatrix.h
00008 
00009 #ifndef USERMATRIX_H
00010 #define USERMATRIX_H
00011 
00012 // LAGO-stuff:
00013 #include "standard.h"
00014 #include "param.h"
00015 
00016 #ifdef LANCZOS_AVAILABLE
00017 // Helmberg-stuff:
00018 #include <hel_tools/matrix.h>
00019 #include <hel_tools/lanczpol.h>
00020 #else
00021 typedef int Integer;
00022 typedef double Real;
/** Transfers the sign of b onto a.
    Returns a unchanged when a and b are both non-negative or both negative,
    and -a otherwise, so the result has the magnitude of a and the sign of b.
    @param a Value that supplies the magnitude.
    @param b Value that supplies the sign.
    @return a or -a, see above.
*/
inline double d_sign(double a,double b) {
  const bool both_nonnegative = (a>=0) && (b>=0);
  const bool both_negative = (a<0) && (b<0);
  if (both_nonnegative || both_negative) return a;
  return -a;
}
00024 #endif
00025 
00026 // TNT-stuff:
00027 #include <tnt_cmat.h>
00028 #include <tnt_vec.h>
00029 
00035 class UserMatrix
00036 #ifdef LANCZOS_AVAILABLE
00037 : public Lanczosmatrix
00038 #endif
00039 {
00047       friend ostream& operator << (ostream& out, const UserMatrix& a)
00048       {  a.print(out); return out; };
00049 
00050    protected:
00053       int dim_;
00054 
00055    public:
00059       UserMatrix(int n=0)
00060       : dim_(n)
00061       { }
00062 
00065       virtual ~UserMatrix() { }
00066 
00070       int dim() const { return dim_; };
00071 
00078       virtual void MultV(UserVector<double>& y,const UserVector<double>& x) const=0;
00079 
00085       virtual void MultV(double* y, const double* x) const {
00086         dvector xx(x, dim()), yy(dim());
00087         MultV(yy,xx);
00088         for (int i=0; i<dim(); i++) y[i]=yy[i];
00089       }
00090 
00097       virtual double yAx(const UserVector<double>& y, const UserVector<double>& x) const {
00098         Pointer<UserVector<double> > z(y.getemptycopy());
00099         MultV(*z,x);
00100         return *z*y;
00101       }
00102 
00108       virtual double xAx(const UserVector<double>& x) const
00109       {  return yAx(x,x);  };
00110 
00116       virtual void AddMult(UserVector<double>& y, const UserVector<double>& x, const double val) const {
00117         y.AddMult(val, (*this * x));
00118       };
00119 
00125       dvector operator*(const dvector& x) const {
00126         dvector y(x.size());
00127         MultV(y, x);
00128         return y;
00129       }
00130 
00136       double* operator*(const double* x) const {
00137         double* y=new double[dim()];
00138         MultV(y,x);
00139         return y;
00140       }
00141 
00142       Pointer<UserVector<double> > operator*(const Pointer<UserVector<double> > x) const {
00143         Pointer<UserVector<double> > y(x->getemptycopy());
00144         MultV(*y,*x);
00145         return y;
00146       }
00147 
00148 #ifdef LANCZOS_AVAILABLE
00149 
00153       Integer lanczosdim() const { return dim(); }
00154 
00162       int lanczosmult(const Matrix& x, Matrix &y) const {
00163         for (int i=0; i<x.coldim(); i++)
00164           ((UserMatrix*)this)->MultV(y.get_store()+i*x.rowdim(), x.get_store()+i*x.rowdim());
00165         y*=-1;
00166         return 0;
00167       }
00168 
00177       int eig_lanczos(vector<dvector>& eig_vec, vector<double>& eig_val, Param *param=NULL) const;
00178 
00186       int eig_lanczos(dvector& eig_vec, double& eig_val, Param *param=NULL) const;
00187 #endif
00188 
00189 #ifdef ARPACK_AVAILABLE
00190 
00195       void arpackmult(double* x, double* y) { MultV(y,x); }
00196 
00203       int eig_arpack(dvector& eig_vec, double& eig_val, Param* param=NULL);
00204 
00211       int eig_arpack(vector<dvector>& eig_vec, vector<double>& eig_val, Param* param=NULL);
00212 #endif
00213 
00222                         virtual int eig_ql(vector<dvector> &eig_vec, vector<double> &eig_val) const;
00223 
00235       int eig(vector<dvector>& eig_vec, vector<double>& eig_val, Param *param=NULL);
00236 
00244       int eig(dvector& eig_vec, double& eig_val, Param *param=NULL);
00245 
00251       virtual void print(ostream &out) const;
00252 };
00253 
00254 #ifdef FILIB_AVAILABLE
00255 class IntervalCompliantMatrix : public UserMatrix {
00256         public:
00257                 IntervalCompliantMatrix(int n=0)
00258                 : UserMatrix(n)
00259                 { }
00260 
00261                 virtual void MultV(IntervalVector& y, const IntervalVector& x) const=0;
00262 
00263                 using UserMatrix::MultV;
00264 
00265                 virtual interval<double> yAx(const IntervalVector& y, const IntervalVector& x) const {
00266                         IntervalVector Ax(x.dim()); this->MultV(Ax,x);
00267                         return y*Ax;
00268                 }
00269                 using UserMatrix::yAx;
00270 
00271                 virtual interval<double> xAx(const IntervalVector& x) const { return yAx(x, x); }
00272                 using UserMatrix::xAx;
00273 };
00274 #endif
00275 
00278 class ExtUserMatrix
00279 #ifdef FILIB_AVAILABLE
00280 : public IntervalCompliantMatrix
00281 #else
00282 : public UserMatrix
00283 #endif
00284 {
00285   public:
00289     ExtUserMatrix(int n=0)
00290 #ifdef FILIB_AVAILABLE
00291                 : IntervalCompliantMatrix(n)
00292 #else
00293                 : UserMatrix(n)
00294 #endif
00295     { }
00296 
00303     virtual double operator()(int row, int col) const=0;
00304 
00305 #ifdef FILIB_AVAILABLE
00306 
00308                 virtual interval<double> xAx_2bx(const IntervalVector& x, const UserVector<double>& b) const {
00309                         return xAx(x)+interval<double>(2,2)*(x*b);
00310                 }
00311 #endif
00312                 
00313 };
00314 
00318 class ShiftMatrix : public UserMatrix {
00319   protected:
00322     Pointer<UserMatrix> A;
00323 
00326     dvector shift;
00327 
00328   public:
00333     ShiftMatrix(Pointer<UserMatrix> A_, const dvector& shift_)
00334     : UserMatrix(shift_.dim()), A(A_), shift(shift_)
00335     { }
00336 
00342     ShiftMatrix(int n, const dvector& shift_)
00343     : UserMatrix(n), A(0), shift(shift_)
00344     { }
00345 
00350     void MultV(UserVector<double>& y, const UserVector<double>& x) const {
00351       if (A) A->MultV(y,x);
00352       else y=0;
00353       y+=x.diagmult(shift);
00354                 }
00355 
00361     void MultV(double* y, const double* x) const {
00362       if (A) {
00363         A->MultV(y,x);
00364         for (int i=0; i<dim(); i++) y[i]+=shift(i)*x[i];
00365       }
00366       else for (int i=0; i<dim(); i++) y[i]=shift(i)*x[i];
00367     }
00368 
00374     double yAx(const UserVector<double>& y, const UserVector<double>& x) const {
00375       return (A ? A->yAx(y,x) : 0) + y * x.diagmult(shift);
00376     }
00377 
00382     void print(ostream &out) const {
00383       out << "ShiftMatrix: dim=" << dim() << " shift=" << shift << endl;
00384       if  (A) out << *A;
00385     }
00386 };
00387 
00391 class MinusMatrix
00392 #ifdef FILIB_AVAILABLE
00393 : public IntervalCompliantMatrix
00394 #else
00395 : public UserMatrix
00396 #endif
00397 {
00398   protected:
00401     Pointer<UserMatrix> A;
00402 
00403   public:
00407     MinusMatrix(Pointer<UserMatrix> A_)
00408 #ifdef FILIB_AVAILABLE
00409                 : IntervalCompliantMatrix
00410 #else
00411                 : UserMatrix
00412 #endif
00413     (A_ ? A_->dim() : 0), A(A_)
00414     { assert(A_ != NULL);
00415     }
00416 
00421     void MultV(UserVector<double>& y, const UserVector<double>& x) const {
00422       A->MultV(y,x);
00423       y*=-1;
00424     }
00425                 
00431     void MultV(double* y, const double* x) const {
00432       A->MultV(y,x);
00433       for (int i=0; i<dim(); i++) y[i]*=-1;
00434     }
00435 
00441     double yAx(const UserVector<double>& y, const UserVector<double>& x) const {
00442       return -A->yAx(y,x);
00443     }
00444 
00445 #ifdef FILIB_AVAILABLE
00446                 virtual void MultV(IntervalVector& y, const IntervalVector& x) const {
00447                         assert(dynamic_cast<const IntervalCompliantMatrix*>((const UserMatrix*)A));
00448                         ((const IntervalCompliantMatrix*)(const UserMatrix*)A)->MultV(y,x);
00449                         y*=-1.;
00450                 }
00451 
00452                 virtual interval<double> yAx(const IntervalVector& y, const IntervalVector& x) const {
00453                         IntervalVector Ax(x.dim());
00454                         assert(dynamic_cast<const IntervalCompliantMatrix*>((const UserMatrix*)A));
00455                         ((const IntervalCompliantMatrix*)(const UserMatrix*)A)->MultV(Ax,x);
00456                         return -(y*Ax);
00457                 }
00458 #endif
00459 
00464     void print(ostream& out) const {
00465       out << "MinusMatrix: dim=" << dim() << endl << *A;
00466     }
00467 };
00468 
00472 class SumMatrix
00473 #ifdef FILIB_AVAILABLE
00474 : public IntervalCompliantMatrix
00475 #else
00476 : public UserMatrix
00477 #endif
00478 {
00479   public:
00480     Pointer<const UserMatrix> A, B;
00481     double a, b;
00482 
00490     SumMatrix(Pointer<const UserMatrix> A_, Pointer<const UserMatrix> B_=NULL, double a_=1., double b_=1.)
00491 #ifdef FILIB_AVAILABLE
00492                 : IntervalCompliantMatrix
00493 #else
00494                 : UserMatrix
00495 #endif
00496     (A_ ? A_->dim() : (B_ ? B_->dim() : 0)), A(A_), B(B_), a(a_), b(b_)
00497     { assert(A_ || B_);
00498     }
00499 
00500     SumMatrix(Pointer<UserMatrix> A_, Pointer<UserMatrix> B_=NULL, double a_=1., double b_=1.)
00501 #ifdef FILIB_AVAILABLE
00502                 : IntervalCompliantMatrix
00503 #else
00504                 : UserMatrix
00505 #endif
00506     (A_ ? A_->dim() : (B_ ? B_->dim() : 0)), A(A_), B(B_), a(a_), b(b_)
00507     { assert(A_ || B_);
00508     }
00509 
00510     SumMatrix(UserMatrix* A_, UserMatrix* B_=NULL, double a_=1., double b_=1.)
00511 #ifdef FILIB_AVAILABLE
00512                 : IntervalCompliantMatrix
00513 #else
00514                 : UserMatrix
00515 #endif
00516     (A_ ? A_->dim() : (B_ ? B_->dim() : 0)), A(A_), B(B_), a(a_), b(b_)
00517     { assert(A_ || B_);
00518     }
00519 
00520                 void MultV(dvector& y, const dvector& x) const {
00521                         if (A) {
00522                                 A->MultV(y,x);
00523                                 y*=a;
00524                                 if (B) {
00525                                   dvector y2(y.dim());
00526                                   B->MultV(y2,x);
00527                                         y.AddMult(b, y2);
00528                                 }
00529                         }       else { // no A, so there is a B
00530                           B->MultV(y,x);
00531                           y*=b;
00532                         }
00533     }
00534 
00535     void MultV(UserVector<double>& y, const UserVector<double>& x) const {
00536       if (A) {
00537         A->MultV(y,x);
00538         y*=a;
00539         if (B) {
00540           Pointer<UserVector<double> > y2(y.getemptycopy());
00541           B->MultV(*y2, x);
00542           y.AddMult(b, *y2);
00543         }
00544       } else {
00545         B->MultV(y,x);
00546         y*=b;
00547       }
00548     }
00549                 using UserMatrix::MultV;
00550 
00551 #ifdef FILIB_AVAILABLE
00552                 virtual void MultV(IntervalVector& y, const IntervalVector& x) const {
00553                         if (A) assert(dynamic_cast<const IntervalCompliantMatrix*>((const UserMatrix*)A));
00554                         if (B) assert(dynamic_cast<const IntervalCompliantMatrix*>((const UserMatrix*)B));
00555                         if (A) {
00556                                 ((const IntervalCompliantMatrix*)(const UserMatrix*)A)->MultV(y,x);
00557                                 y*=a;
00558                                 if (B) {
00559                                   IntervalVector y2(y.dim());
00560                                   ((const IntervalCompliantMatrix*)(const UserMatrix*)B)->MultV(y2,x);
00561                                         y.AddMult(interval<double>(b), y2);
00562                                 }
00563                         }       else { // no A, so there is a B
00564                           ((const IntervalCompliantMatrix*)(const UserMatrix*)B)->MultV(y,x);
00565                           y*=b;
00566                         }
00567                 }
00568 #endif
00569 
00570     void print(ostream& out) const {
00571       out << "SumMatrix: dim=" << dim() << " a= " << a << " b= " << b << " A= ";
00572       if (A) out << endl << *A; else out << " NULL ";
00573       out << " B= ";
00574       if (B) out << endl << *B; else out << " NULL " << endl;
00575     }
00576 
00577 };
00578 
00581 class DiagMatrix: public ExtUserMatrix {
00582    public:
00585       bool one;
00586 
00589       Pointer<UserVector<double> > diag;
00590 
00595       DiagMatrix(Pointer<UserVector<double> > b_)
00596       : ExtUserMatrix(b_->dim()), diag(b_), one(*b_==1)
00597       { }
00598 
00603       void MultV(UserVector<double>& y, const UserVector<double>& x) const {
00604         if (one) y=x;
00605         else diag->diagmult(y,x);
00606       }
00607 
00608 
00609 #ifdef FILIB_AVAILABLE
00610                         void MultV(IntervalVector& y, const IntervalVector& x) const {
00611                                 if (one) y=x;
00612                                 else for (int i=0; i<dim(); i++) (y[i]=x(i))*=(*diag)(i);
00613                         }
00614 #endif
00615 
00616 #ifdef FILIB_AVAILABLE
00617                         using IntervalCompliantMatrix::MultV;
00618 #else
00619                         using UserMatrix::MultV;
00620 #endif
00621 
00627       double yAx(const UserVector<double>& y,const UserVector<double>& x) const {
00628         if (one) return y*x;
00629         Pointer<UserVector<double> > z(y.getemptycopy());
00630         diag->diagmult(*z, x);
00631         return *z*y;
00632       }
00633 
00634 #ifdef FILIB_AVAILABLE
00635                         using IntervalCompliantMatrix::yAx;
00636 #else
00637                         using UserMatrix::yAx;
00638 #endif
00639       
00640                         double operator()(int row, int col) const {
00641         if (row==col)
00642           if (one) return 1.;
00643           else return (*diag)(row);
00644         else return 0;
00645       }
00646 
00651       void print(ostream &out) const {
00652         out << "DiagMatrix: dim: " << dim() << " Diag: " << *diag;
00653       }
00654 };
00655 
00656 class BlockMatrix
00657 #ifdef FILIB_AVAILABLE
00658 : public IntervalCompliantMatrix
00659 #else
00660 : public UserMatrix
00661 #endif
00662 {
00663   public:
00666     vector<Pointer<UserMatrix> > A;
00667 
00670     vector<ivector> block;
00671 
00672                 void set_dim() {
00673                         dim_=0;
00674                         for (int i=0; i<block.size(); i++) dim_+=block[i].size();
00675                 }
00676                 
00680     BlockMatrix(int n=0)
00681 #ifdef FILIB_AVAILABLE
00682                 : IntervalCompliantMatrix(n)
00683 #else
00684                 : UserMatrix(n)
00685 #endif
00686     { };
00687 
00688 
00692     BlockMatrix(const vector<ivector>& block_)
00693 #ifdef FILIB_AVAILABLE
00694                 : IntervalCompliantMatrix(),
00695 #else
00696                 : UserMatrix(),
00697 #endif
00698     block(block_), A(block_.size())
00699     { set_dim(); }
00700 
00705     BlockMatrix(const vector<ivector>& block_, const vector<Pointer<UserMatrix> >& A_)
00706 #ifdef FILIB_AVAILABLE
00707                 : IntervalCompliantMatrix(),
00708 #else
00709                 : UserMatrix(),
00710 #endif
00711     block(block_), A(A_)
00712     { set_dim(); }
00713 
00717     BlockMatrix(Pointer<UserMatrix> A_)
00718 #ifdef FILIB_AVAILABLE
00719                 : IntervalCompliantMatrix(A_ ? A_->dim() : 0),
00720 #else
00721                 : UserMatrix(A_ ? A_->dim() : 0),
00722 #endif
00723     block(1), A(1)
00724     { block[0].resize(dim_);
00725       for (int i=0; i<dim_; i++) block[0][i]=i;
00726       A[0]=A_;
00727     }
00728 
00732     BlockMatrix(const BlockMatrix& b)
00733 #ifdef FILIB_AVAILABLE
00734                 : IntervalCompliantMatrix(b.dim()),
00735 #else
00736                 : UserMatrix(b.dim()),
00737 #endif
00738     A(b.A), block(b.block)
00739     { }
00740 
00745     void MultV(UserVector<double>& y, const UserVector<double>& x) const;
00746                 using UserMatrix::MultV;
00747 
00754     double yAx(const UserVector<double>& y,const UserVector<double>& x) const;
00755 
00756 #ifdef FILIB_AVAILABLE
00757                 void MultV(IntervalVector& y, const IntervalVector& x) const;
00758 
00759                 interval<double> yAx(const IntervalVector& y, const IntervalVector& x) const;
00760 #endif
00761 
00766     Pointer<UserMatrix> operator[](int k) const {
00767       return A[k];
00768     }
00769 
00774     void print(ostream &out) const {
00775       out << "BlockMatrix: dim: " << dim() << " Blocks: " << block.size() << endl;
00776       for (int i=0; i<block.size(); i++) {
00777         out << "block " << i << ": " << block[i];
00778         if (A[i]) out << *A[i];
00779       }
00780     }
00781 
00782 };
00783 
00784 //--------------------------------------------------------
00785 
00788 class DenseMatrix : public ExtUserMatrix {
00789   private:
00809     void tred2(int nm, int n, double* a, double *d, double *e, double *z) const;
00810 
00834     int imtql2(int nm, int n, double *d, double *e, double *z) const;
00835 
00836   protected:
00839     TNT::Matrix<double> A;
00840 
00841   public:
00845     bool allow_destroy;
00846 
00853     DenseMatrix(bool allow_destroy_=false)
00854     : ExtUserMatrix(), A(), allow_destroy(allow_destroy_)
00855     { }
00856 
00865     DenseMatrix(int n, double val=0., bool allow_destroy_=false)
00866     : ExtUserMatrix(n), A(n, n, val), allow_destroy(allow_destroy_)
00867     { }
00868 
00878     DenseMatrix(const UserMatrix& A_, bool allow_destroy_=false);
00879 
00884     DenseMatrix(const ExtUserMatrix& A_, bool allow_destroy_=false)
00885     : ExtUserMatrix(A_.dim()), A(A_.dim(), A_.dim(), 0.), allow_destroy(allow_destroy_)
00886     { for (int i=0; i<dim(); i++)
00887                 for (int j=0; j<dim(); j++)
00888                   A[i][j]=A_(i,j);
00889     }
00890 
00898     DenseMatrix(const TNT::Matrix<double>& A_, bool allow_destroy_=false)
00899     : ExtUserMatrix(A_.num_rows()), A(A_), allow_destroy(allow_destroy_)
00900     { }
00901 
00905     DenseMatrix(const DenseMatrix& D)
00906     : ExtUserMatrix(D.dim()), A(D.A), allow_destroy(D.allow_destroy)
00907     { }
00908 
00912     operator double*() {
00913       return *(double**)A;
00914     }
00915 
00921     DenseMatrix& operator=(const DenseMatrix& D) {
00922       if (this != &D) {
00923         dim_=D.dim();
00924         A=D.A;
00925         allow_destroy=D.allow_destroy;
00926       }
00927       return *this;
00928     }
00929 
00934     DenseMatrix& operator=(const double scalar) {
00935       A=scalar;
00936       return *this;
00937     }
00938 
00945     dvector operator[](int i) const {
00946       return dvector(A[i], dim());
00947     }
00948 
00953 /*    void set_row(dvector &a, const int i) {
00954       a=operator[](i);
00955     }
00956 */
00963     double& operator()(int row, int col) {
00964       return A[row][col];
00965     }
00966 
00967     double operator()(int row, int col) const {
00968       return A[row][col];
00969     }
00970 
00976     DenseMatrix operator+(const DenseMatrix &B) const {
00977       return DenseMatrix(A + B.A);
00978     }
00979 
00985     DenseMatrix operator-(const DenseMatrix &B) const {
00986       return DenseMatrix(A - B.A);
00987     }
00988 
00995     DenseMatrix operator*(const DenseMatrix &B) const {
00996       return DenseMatrix(A * B.A);
00997     }
00998 
01005     void MultV(UserVector<double>& y_, const UserVector<double>& x_) const;
01006 
01007 #ifdef FILIB_AVAILABLE
01008                 void MultV(IntervalVector& y, const IntervalVector& x) const;
01009 #endif
01010 
01017     void MultV(double* y_, const double* x_) const;
01018 
01019     void set_random(const dvector& lambda);
01020 
01030     int eig_ql(vector<dvector> &eig_vec, vector<double> &eig_val) const;
01031 
01039     int eig_ql(dvector &eig_vec, double &eig_val) const;
01040 
01045     void print(ostream &out) const {
01046       out << "DenseMatrix: " << A;
01047     }
01048 };
01049 
01050 class SparseMatrix {
01051   protected:
01055     map<pair<int,int>, double> values;
01056         
01060         int nz;
01064         double* val;
01068         int* row_ind;
01072         int* col_ptr;
01073 
01076                 int rows_, cols_;       
01077 
01078   public:
01083     SparseMatrix(int rows__, int cols__)
01084     : rows_(rows__), cols_(cols__), val(0), row_ind(0), col_ptr(0), nz(-1)
01085     { }
01086 
01090     SparseMatrix(const SparseMatrix& A_)
01091                 : rows_(A_.rows_), cols_(A_.cols_), nz(A_.nz),
01092                   val(A_.nz>=0 ? new double[A_.nz] : NULL),
01093                   row_ind(A_.nz>=0 ? new int[A_.nz] : NULL),
01094                   col_ptr(A_.nz>=0 ? new int[A_.cols_+1] : NULL)
01095     { if (nz>0) {
01096                 memcpy(val, A_.val, nz * sizeof(double));
01097               memcpy(row_ind, A_.row_ind, nz * sizeof(double));
01098             memcpy(col_ptr, A_.col_ptr, (cols_+1) * sizeof(double));
01099           }
01100     }
01101 
01106                 SparseMatrix(const UserMatrix& A_, bool no_finish=false);
01107 
01112                 SparseMatrix(const ExtUserMatrix& A_, bool no_finish=false)
01113                 : rows_(A_.dim()), cols_(A_.dim()), val(NULL), row_ind(NULL), col_ptr(NULL), nz(-1)
01114                 { for (int i=0; i<rows_; i++)
01115                                 for (int j=0; j<cols_; j++)
01116                                         AddElement(i,j,A_(i,j));
01117                         if (!no_finish) finish();
01118                 }
01119                 
01123                 virtual ~SparseMatrix() {
01124                   if (val) delete[] val;
01125                   if (row_ind) delete[] row_ind;
01126                   if (col_ptr) delete[] col_ptr;
01127                 }
01128 
01132                 void resize(int rows__, int cols__);
01133 
01134                 int rows() const { return rows_; }
01135                 int cols() const { return cols_; }
01136                 
01137                 const int* GetRowInd() const { return row_ind; }
01138                 const int* GetColPtr() const { return col_ptr; }
01139                 const double* GetVal() const { return val; }
01140                 double* GetVal() { return val; }
01141                 
01142                 virtual double operator()(int row, int col) const {
01143 #ifndef NO_SPARSEMATRIX_ASSERTS
01144                         assert(val!=NULL);
01145 #endif
01146                   for (int i=col_ptr[col]; i<col_ptr[col+1]; i++)
01147                     if (row_ind[i]==row) return val[i];
01148                   return 0.;
01149                 }
01150                 
01151                 int nonzeros() const {
01152                         if (nz>=0) return nz;
01153                         else return values.size();
01154                 }
01155 
01163     void AddElement(int row, int col, double v, bool check_zero=true) {
01164 #ifndef NO_SPARSEMATRIX_ASSERTS
01165       assert(val==NULL);
01166       assert(0<=row && row<rows_);
01167       assert(0<=col && col<cols_);
01168 #endif
01169       if (check_zero && fabs(v)<rtol) return;
01170       values.insert(pair<pair<int,int>, double>(pair<int,int>(col, row), v));
01171     }
01172 
01173     void AddToElement(int row, int col, double v, bool check_zero=true) {
01174 #ifndef NO_SPARSEMATRIX_ASSERTS
01175       assert(val==NULL);
01176       assert(0<=row && row<rows_);
01177       assert(0<=col && col<cols_);
01178 #endif
01179                         if (check_zero && fabs(v)<rtol) return;
01180                         map<pair<int,int>, double >::iterator it(values.find(pair<int,int>(col, row)));
01181 
01182                         if (it==values.end()) values.insert(pair<pair<int,int>, double>(pair<int,int>(col, row), v));
01183                         else it->second+=v;
01184     }
01185 
01192     SparseMatrix& operator+=(const UserMatrix& A_);
01193 
01199     SparseMatrix& operator+=(const ExtUserMatrix& A_) {
01200 #ifndef NO_SPARSEMATRIX_ASSERTS
01201                         assert(rows_==A_.dim());
01202                         assert(cols_==A_.dim());
01203 #endif
01204         for (int row=0; row<rows_; row++)
01205                 for (int col=0; col<cols_; col++)
01206                         AddToElement(row, col, A_(row,col));
01207                         return *this;
01208     }
01209 
01210     SparseMatrix& operator=(const double v);
01211 
01212     SparseMatrix& operator*=(const double v);
01213 
01214                 void set_block(const SparseMatrix& A, const ivector& indices);
01215 
01220     void finish();
01221 
01226                 virtual void MultV(double* y, const double* x) const {
01227 #ifndef NO_SPARSEMATRIX_ASSERTS
01228       assert(val);
01229 #endif
01230                   memset(y, 0, rows_ * sizeof(double));
01231                   int j=0;
01232                   for (int col=0; col<cols_; col++, x++)
01233                     for (; j<col_ptr[col+1]; j++)
01234                       y[row_ind[j]] += *x * val[j];
01235                 }
01236 
01241     virtual void MultV(dvector& y, const dvector& x) const {
01242 #ifndef NO_SPARSEMATRIX_ASSERTS
01243       assert(y.dim()==rows_);
01244       assert(x.dim()==cols_);
01245 #endif
01246       MultV((Pointer<double>)y, (Pointer<double>)x);
01247     }
01248 
01253     virtual void MultV(UserVector<double>& y, const UserVector<double>& x) const {
01254 #ifndef NO_SPARSEMATRIX_ASSERTS
01255       assert(val);
01256                         assert(y.dim()==rows_);
01257                         assert(x.dim()==cols_);
01258 #endif
01259       y=0;
01260                   int j=0;
01261       double xi;
01262                   for (int col=0; col<cols_; col++)
01263         if (xi=x(col))
01264                     for (j=col_ptr[col]; j<col_ptr[col+1]; j++)
01265                       y[row_ind[j]] += xi * val[j];
01266                 }
01267 
01268 #ifdef FILIB_AVAILABLE
01269                 virtual void MultV(IntervalVector& y, const IntervalVector& x) const {
01270 #ifndef NO_SPARSEMATRIX_ASSERTS
01271       assert(val);
01272                         assert(y.dim()==rows_);
01273                         assert(x.dim()==cols_);
01274 #endif
01275                   int j=0;
01276                         interval<double> zero(0.);
01277                         y=zero;
01278                   for (int col=0; col<cols_; col++)
01279         if (x(col)!=zero)
01280                     for (j=col_ptr[col]; j<col_ptr[col+1]; j++)
01281                       y[row_ind[j]] += val[j] * x(col);
01282                 }
01283 
01284                 virtual interval<double> yAx(const IntervalVector& y, const IntervalVector& x) const {
01285 #ifndef NO_SPARSEMATRIX_ASSERTS
01286       assert(val);
01287                         assert(y.dim()==rows_);
01288                         assert(x.dim()==cols_);
01289 #endif
01290       interval<double> ret(0.);
01291                         interval<double> zero(0.);
01292 
01293                   int j=0;
01294                   for (int col=0; col<cols_; col++)
01295         if (x(col)!=zero)
01296                     for (j=col_ptr[col]; j<col_ptr[col+1]; j++)
01297                       ret+=y(row_ind[j]) * x(col) * val[j];
01298 
01299       return ret;
01300     }
01301     
01302     virtual interval<double> xAx(const IntervalVector& x) const {
01303 #ifndef NO_SPARSEMATRIX_ASSERTS
01304                 assert(val);
01305                 assert(rows_==cols_);
01306                 assert(x.dim()==cols_);
01307 #endif
01308                 interval<double> ret(0.);
01309                 interval<double> zero(0.);
01310 
01311                 int j=0;
01312                 for (int col=0; col<cols_; col++)
01313                 if (x(col)!=zero) {
01314                         for (j=col_ptr[col]; j<col_ptr[col+1]; j++)
01315                                 if (row_ind[j]==col) ret+=val[j]*sqr(x(col));
01316                                 else ret+=x(row_ind[j]) * x(col) * val[j];
01317                 }
01318                 return ret;
01319     }
01320 
01321                 virtual interval<double> xAx_2bx(const IntervalVector& x, const UserVector<double>& b) const;
01322 #endif
01323 
01328                 virtual void MultV(SparseVector<double>& y, const SparseVector<double>& x) const {
01329 #ifndef NO_SPARSEMATRIX_ASSERTS
01330                   assert(val);
01331                         assert(y.dim()==rows_);
01332                         assert(x.dim()==cols_);
01333 #endif
01334                   double* y0=new double[cols_]; memset(y0, 0, cols_*sizeof(double));
01335                   int j;
01336       SparseVector<double>::VectorElement* v=x.head->next;
01337       while (v) {
01338                     for (j=col_ptr[v->index]; j<col_ptr[v->index+1]; j++)
01339                       y0[row_ind[j]] += v->value * val[j];
01340                     v=v->next;
01341                   }
01342                   y.set(y0);
01343                   delete y0;
01344                 }
01345 
01346     virtual double yAx(const UserVector<double>& y, const UserVector<double>& x) const {
01347 #ifndef NO_SPARSEMATRIX_ASSERTS
01348       assert(val);
01349                         assert(y.dim()==rows_);
01350                         assert(x.dim()==cols_);
01351 #endif
01352       double ret=0;
01353 
01354                   int j=0;
01355       double xi;
01356                   for (int col=0; col<cols_; col++)
01357         if (xi=x(col))
01358                     for (j=col_ptr[col]; j<col_ptr[col+1]; j++)
01359                       ret+=y(row_ind[j]) * xi * val[j];
01360 
01361       return ret;
01362     }
01363 
01364     virtual double yAx(const dvector& y, const dvector& x) const {
01365 #ifndef NO_SPARSEMATRIX_ASSERTS
01366                         assert(y.dim()==rows_);
01367                         assert(x.dim()==cols_);
01368       assert(val);
01369 #endif
01370       double ret=0;
01371       double* y0=(Pointer<double>)y;
01372       double* x0=(Pointer<double>)x;
01373 
01374       int j=0;
01375       for (int col=0; col<cols_; col++, x0++)
01376         for (; j<col_ptr[col+1]; j++)
01377           ret+=y[row_ind[j]] * *x0 * val[j];
01378 
01379       return ret;
01380     }
01381 
01382     virtual double xAx(const UserVector<double>& x) const { return yAx(x,x); }
01383 
01389     virtual void AddMult(UserVector<double>& y, const UserVector<double>& x, const double alpha) const {
01390 #ifndef NO_SPARSEMATRIX_ASSERTS
01391                         assert(y.dim()==rows_);
01392                         assert(x.dim()==cols_);
01393       assert(val);
01394 #endif
01395                   int j=0;
01396       double xi;
01397                   for (int col=0; col<cols_; col++)
01398         if (xi=alpha*x(col))
01399                       for (j=col_ptr[col]; j<col_ptr[col+1]; j++)
01400                         y[row_ind[j]] += xi * val[j];
01401     };
01402 
01407     void plot(char* filename) const;
01408 
01412     virtual void print(ostream& out) const;
01413 
01414 };
01415 
01416 class SparseMatrix2 : public ExtUserMatrix, public SparseMatrix {
01417         public:
01421     SparseMatrix2(int n)
01422     : SparseMatrix(n, n), ExtUserMatrix(n)
01423     { }
01424 
01428     SparseMatrix2(const SparseMatrix& A_)
01429                 : SparseMatrix(A_), ExtUserMatrix(A_.rows())
01430     { assert(A_.rows()==A_.cols());
01431                 }
01432 
01437                 SparseMatrix2(const UserMatrix& A_, bool no_finish=false)
01438                 : SparseMatrix(A_, no_finish), ExtUserMatrix(A_.dim())
01439                 { }
01440 
01445                 SparseMatrix2(const ExtUserMatrix& A_, bool no_finish=false)
01446                 : SparseMatrix(A_, no_finish), ExtUserMatrix(A_.dim())
01447                 { }
01448                 
01449                 double operator()(int row, int col) const {     return SparseMatrix::operator()(row, col);      }
01450 
01451     SparseMatrix2& operator+=(const UserMatrix& A_) {   SparseMatrix::operator+=(A_);   return *this;   }
01452 
01453     SparseMatrix2& operator+=(const ExtUserMatrix& A_) { SparseMatrix::operator+=(A_); return *this; }
01454 
01455     SparseMatrix2& operator=(const double v) { SparseMatrix::operator=(v); return *this; }
01456 
01457     SparseMatrix2& operator*=(const double v) { SparseMatrix::operator*=(v); return *this; }
01458 
01459                 void set_block(const SparseMatrix& A, const ivector& indices) { SparseMatrix::set_block(A, indices); }
01460 
01461                 void MultV(double* y, const double* x) const { SparseMatrix::MultV(y,x); }
01462 
01463     void MultV(dvector& y, const dvector& x) const { SparseMatrix::MultV(y,x); }
01464 
01465     void MultV(UserVector<double>& y, const UserVector<double>& x) const { SparseMatrix::MultV(y,x); }
01466 
01467 #ifdef FILIB_AVAILABLE
01468     void MultV(IntervalVector& y, const IntervalVector& x) const { SparseMatrix::MultV(y,x); }
01469 #endif
01470 
01471                 void MultV(SparseVector<double>& y, const SparseVector<double>& x) const { SparseMatrix::MultV(y,x); }
01472 
01473     double yAx(const UserVector<double>& y, const UserVector<double>& x) const { return SparseMatrix::yAx(y,x); }
01474 
01475     double yAx(const dvector& y, const dvector& x) const { return SparseMatrix::yAx(y,x); }
01476 
01477 #ifdef FILIB_AVAILABLE
01478                 using IntervalCompliantMatrix::yAx;
01479 #else           
01480                 using UserMatrix::yAx;
01481 #endif
01482     
01483                 double xAx(const UserVector<double>& x) const { return SparseMatrix::xAx(x); }
01484 
01485 #ifdef FILIB_AVAILABLE
01486                 interval<double> xAx(const IntervalVector& x) const { return SparseMatrix::xAx(x); }
01487 
01488                 interval<double> xAx_2bx(const IntervalVector& x, const UserVector<double>& b) const { return SparseMatrix::xAx_2bx(x, b); }
01489 #endif
01490 
01491     virtual void AddMult(UserVector<double>& y, const UserVector<double>& x, const double alpha) const { SparseMatrix::AddMult(y,x,alpha); }
01492                 
01493                 void make_symmetric();
01494                 
01500     void set_random(int num_el, double max=1.);
01501                 
01509     static SparseMatrix2* random(int n, int num_el, double max=1.) {
01510                   SparseMatrix2* A=new SparseMatrix2(n);
01511                   A->set_random(num_el, max);
01512                   return A;
01513                 }
01514                 
01515     virtual void print(ostream& out) const { SparseMatrix::print(out); }
01516 };
01517 
01518 #endif // USERMATRIX_H

Generated on Mon Oct 20 03:12:06 2008 for LaGO by  doxygen 1.4.7