00001
00002
00003
00004
00005
00006
00007
00008
00009 #ifndef USERMATRIX_H
00010 #define USERMATRIX_H
00011
00012
00013 #include "standard.h"
00014 #include "param.h"
00015
00016 #ifdef LANCZOS_AVAILABLE
00017
00018 #include <hel_tools/matrix.h>
00019 #include <hel_tools/lanczpol.h>
00020 #else
// Fallback scalar types with the names hel_tools would otherwise provide.
typedef int Integer;
typedef double Real;

/**
 * Sign transfer in the style of Fortran's SIGN(a, b), used when hel_tools is unavailable.
 * Returns a when a and b lie on the same side of zero (both >= 0 or both < 0) and -a
 * otherwise; for ordinary (non-NaN, non-signed-zero) values this is |a| carrying the
 * sign of b, with b == 0 counting as positive.
 */
inline double d_sign(double a, double b) {
  const bool both_nonnegative = (a >= 0) && (b >= 0);
  const bool both_negative = (a < 0) && (b < 0);
  if (both_nonnegative || both_negative) return a;
  return -a;
}
00024 #endif
00025
00026
00027 #include <tnt_cmat.h>
00028 #include <tnt_vec.h>
00029
00035 class UserMatrix
00036 #ifdef LANCZOS_AVAILABLE
00037 : public Lanczosmatrix
00038 #endif
00039 {
00047 friend ostream& operator << (ostream& out, const UserMatrix& a)
00048 { a.print(out); return out; };
00049
00050 protected:
00053 int dim_;
00054
00055 public:
00059 UserMatrix(int n=0)
00060 : dim_(n)
00061 { }
00062
00065 virtual ~UserMatrix() { }
00066
00070 int dim() const { return dim_; };
00071
00078 virtual void MultV(UserVector<double>& y,const UserVector<double>& x) const=0;
00079
00085 virtual void MultV(double* y, const double* x) const {
00086 dvector xx(x, dim()), yy(dim());
00087 MultV(yy,xx);
00088 for (int i=0; i<dim(); i++) y[i]=yy[i];
00089 }
00090
00097 virtual double yAx(const UserVector<double>& y, const UserVector<double>& x) const {
00098 Pointer<UserVector<double> > z(y.getemptycopy());
00099 MultV(*z,x);
00100 return *z*y;
00101 }
00102
00108 virtual double xAx(const UserVector<double>& x) const
00109 { return yAx(x,x); };
00110
00116 virtual void AddMult(UserVector<double>& y, const UserVector<double>& x, const double val) const {
00117 y.AddMult(val, (*this * x));
00118 };
00119
00125 dvector operator*(const dvector& x) const {
00126 dvector y(x.size());
00127 MultV(y, x);
00128 return y;
00129 }
00130
00136 double* operator*(const double* x) const {
00137 double* y=new double[dim()];
00138 MultV(y,x);
00139 return y;
00140 }
00141
00142 Pointer<UserVector<double> > operator*(const Pointer<UserVector<double> > x) const {
00143 Pointer<UserVector<double> > y(x->getemptycopy());
00144 MultV(*y,*x);
00145 return y;
00146 }
00147
00148 #ifdef LANCZOS_AVAILABLE
00149
00153 Integer lanczosdim() const { return dim(); }
00154
00162 int lanczosmult(const Matrix& x, Matrix &y) const {
00163 for (int i=0; i<x.coldim(); i++)
00164 ((UserMatrix*)this)->MultV(y.get_store()+i*x.rowdim(), x.get_store()+i*x.rowdim());
00165 y*=-1;
00166 return 0;
00167 }
00168
00177 int eig_lanczos(vector<dvector>& eig_vec, vector<double>& eig_val, Param *param=NULL) const;
00178
00186 int eig_lanczos(dvector& eig_vec, double& eig_val, Param *param=NULL) const;
00187 #endif
00188
00189 #ifdef ARPACK_AVAILABLE
00190
00195 void arpackmult(double* x, double* y) { MultV(y,x); }
00196
00203 int eig_arpack(dvector& eig_vec, double& eig_val, Param* param=NULL);
00204
00211 int eig_arpack(vector<dvector>& eig_vec, vector<double>& eig_val, Param* param=NULL);
00212 #endif
00213
00222 virtual int eig_ql(vector<dvector> &eig_vec, vector<double> &eig_val) const;
00223
00235 int eig(vector<dvector>& eig_vec, vector<double>& eig_val, Param *param=NULL);
00236
00244 int eig(dvector& eig_vec, double& eig_val, Param *param=NULL);
00245
00251 virtual void print(ostream &out) const;
00252 };
00253
00254 #ifdef FILIB_AVAILABLE
00255 class IntervalCompliantMatrix : public UserMatrix {
00256 public:
00257 IntervalCompliantMatrix(int n=0)
00258 : UserMatrix(n)
00259 { }
00260
00261 virtual void MultV(IntervalVector& y, const IntervalVector& x) const=0;
00262
00263 using UserMatrix::MultV;
00264
00265 virtual interval<double> yAx(const IntervalVector& y, const IntervalVector& x) const {
00266 IntervalVector Ax(x.dim()); this->MultV(Ax,x);
00267 return y*Ax;
00268 }
00269 using UserMatrix::yAx;
00270
00271 virtual interval<double> xAx(const IntervalVector& x) const { return yAx(x, x); }
00272 using UserMatrix::xAx;
00273 };
00274 #endif
00275
00278 class ExtUserMatrix
00279 #ifdef FILIB_AVAILABLE
00280 : public IntervalCompliantMatrix
00281 #else
00282 : public UserMatrix
00283 #endif
00284 {
00285 public:
00289 ExtUserMatrix(int n=0)
00290 #ifdef FILIB_AVAILABLE
00291 : IntervalCompliantMatrix(n)
00292 #else
00293 : UserMatrix(n)
00294 #endif
00295 { }
00296
00303 virtual double operator()(int row, int col) const=0;
00304
00305 #ifdef FILIB_AVAILABLE
00306
00308 virtual interval<double> xAx_2bx(const IntervalVector& x, const UserVector<double>& b) const {
00309 return xAx(x)+interval<double>(2,2)*(x*b);
00310 }
00311 #endif
00312
00313 };
00314
00318 class ShiftMatrix : public UserMatrix {
00319 protected:
00322 Pointer<UserMatrix> A;
00323
00326 dvector shift;
00327
00328 public:
00333 ShiftMatrix(Pointer<UserMatrix> A_, const dvector& shift_)
00334 : UserMatrix(shift_.dim()), A(A_), shift(shift_)
00335 { }
00336
00342 ShiftMatrix(int n, const dvector& shift_)
00343 : UserMatrix(n), A(0), shift(shift_)
00344 { }
00345
00350 void MultV(UserVector<double>& y, const UserVector<double>& x) const {
00351 if (A) A->MultV(y,x);
00352 else y=0;
00353 y+=x.diagmult(shift);
00354 }
00355
00361 void MultV(double* y, const double* x) const {
00362 if (A) {
00363 A->MultV(y,x);
00364 for (int i=0; i<dim(); i++) y[i]+=shift(i)*x[i];
00365 }
00366 else for (int i=0; i<dim(); i++) y[i]=shift(i)*x[i];
00367 }
00368
00374 double yAx(const UserVector<double>& y, const UserVector<double>& x) const {
00375 return (A ? A->yAx(y,x) : 0) + y * x.diagmult(shift);
00376 }
00377
00382 void print(ostream &out) const {
00383 out << "ShiftMatrix: dim=" << dim() << " shift=" << shift << endl;
00384 if (A) out << *A;
00385 }
00386 };
00387
00391 class MinusMatrix
00392 #ifdef FILIB_AVAILABLE
00393 : public IntervalCompliantMatrix
00394 #else
00395 : public UserMatrix
00396 #endif
00397 {
00398 protected:
00401 Pointer<UserMatrix> A;
00402
00403 public:
00407 MinusMatrix(Pointer<UserMatrix> A_)
00408 #ifdef FILIB_AVAILABLE
00409 : IntervalCompliantMatrix
00410 #else
00411 : UserMatrix
00412 #endif
00413 (A_ ? A_->dim() : 0), A(A_)
00414 { assert(A_ != NULL);
00415 }
00416
00421 void MultV(UserVector<double>& y, const UserVector<double>& x) const {
00422 A->MultV(y,x);
00423 y*=-1;
00424 }
00425
00431 void MultV(double* y, const double* x) const {
00432 A->MultV(y,x);
00433 for (int i=0; i<dim(); i++) y[i]*=-1;
00434 }
00435
00441 double yAx(const UserVector<double>& y, const UserVector<double>& x) const {
00442 return -A->yAx(y,x);
00443 }
00444
00445 #ifdef FILIB_AVAILABLE
00446 virtual void MultV(IntervalVector& y, const IntervalVector& x) const {
00447 assert(dynamic_cast<const IntervalCompliantMatrix*>((const UserMatrix*)A));
00448 ((const IntervalCompliantMatrix*)(const UserMatrix*)A)->MultV(y,x);
00449 y*=-1.;
00450 }
00451
00452 virtual interval<double> yAx(const IntervalVector& y, const IntervalVector& x) const {
00453 IntervalVector Ax(x.dim());
00454 assert(dynamic_cast<const IntervalCompliantMatrix*>((const UserMatrix*)A));
00455 ((const IntervalCompliantMatrix*)(const UserMatrix*)A)->MultV(Ax,x);
00456 return -(y*Ax);
00457 }
00458 #endif
00459
00464 void print(ostream& out) const {
00465 out << "MinusMatrix: dim=" << dim() << endl << *A;
00466 }
00467 };
00468
00472 class SumMatrix
00473 #ifdef FILIB_AVAILABLE
00474 : public IntervalCompliantMatrix
00475 #else
00476 : public UserMatrix
00477 #endif
00478 {
00479 public:
00480 Pointer<const UserMatrix> A, B;
00481 double a, b;
00482
00490 SumMatrix(Pointer<const UserMatrix> A_, Pointer<const UserMatrix> B_=NULL, double a_=1., double b_=1.)
00491 #ifdef FILIB_AVAILABLE
00492 : IntervalCompliantMatrix
00493 #else
00494 : UserMatrix
00495 #endif
00496 (A_ ? A_->dim() : (B_ ? B_->dim() : 0)), A(A_), B(B_), a(a_), b(b_)
00497 { assert(A_ || B_);
00498 }
00499
00500 SumMatrix(Pointer<UserMatrix> A_, Pointer<UserMatrix> B_=NULL, double a_=1., double b_=1.)
00501 #ifdef FILIB_AVAILABLE
00502 : IntervalCompliantMatrix
00503 #else
00504 : UserMatrix
00505 #endif
00506 (A_ ? A_->dim() : (B_ ? B_->dim() : 0)), A(A_), B(B_), a(a_), b(b_)
00507 { assert(A_ || B_);
00508 }
00509
00510 SumMatrix(UserMatrix* A_, UserMatrix* B_=NULL, double a_=1., double b_=1.)
00511 #ifdef FILIB_AVAILABLE
00512 : IntervalCompliantMatrix
00513 #else
00514 : UserMatrix
00515 #endif
00516 (A_ ? A_->dim() : (B_ ? B_->dim() : 0)), A(A_), B(B_), a(a_), b(b_)
00517 { assert(A_ || B_);
00518 }
00519
00520 void MultV(dvector& y, const dvector& x) const {
00521 if (A) {
00522 A->MultV(y,x);
00523 y*=a;
00524 if (B) {
00525 dvector y2(y.dim());
00526 B->MultV(y2,x);
00527 y.AddMult(b, y2);
00528 }
00529 } else {
00530 B->MultV(y,x);
00531 y*=b;
00532 }
00533 }
00534
00535 void MultV(UserVector<double>& y, const UserVector<double>& x) const {
00536 if (A) {
00537 A->MultV(y,x);
00538 y*=a;
00539 if (B) {
00540 Pointer<UserVector<double> > y2(y.getemptycopy());
00541 B->MultV(*y2, x);
00542 y.AddMult(b, *y2);
00543 }
00544 } else {
00545 B->MultV(y,x);
00546 y*=b;
00547 }
00548 }
00549 using UserMatrix::MultV;
00550
00551 #ifdef FILIB_AVAILABLE
00552 virtual void MultV(IntervalVector& y, const IntervalVector& x) const {
00553 if (A) assert(dynamic_cast<const IntervalCompliantMatrix*>((const UserMatrix*)A));
00554 if (B) assert(dynamic_cast<const IntervalCompliantMatrix*>((const UserMatrix*)B));
00555 if (A) {
00556 ((const IntervalCompliantMatrix*)(const UserMatrix*)A)->MultV(y,x);
00557 y*=a;
00558 if (B) {
00559 IntervalVector y2(y.dim());
00560 ((const IntervalCompliantMatrix*)(const UserMatrix*)B)->MultV(y2,x);
00561 y.AddMult(interval<double>(b), y2);
00562 }
00563 } else {
00564 ((const IntervalCompliantMatrix*)(const UserMatrix*)B)->MultV(y,x);
00565 y*=b;
00566 }
00567 }
00568 #endif
00569
00570 void print(ostream& out) const {
00571 out << "SumMatrix: dim=" << dim() << " a= " << a << " b= " << b << " A= ";
00572 if (A) out << endl << *A; else out << " NULL ";
00573 out << " B= ";
00574 if (B) out << endl << *B; else out << " NULL " << endl;
00575 }
00576
00577 };
00578
00581 class DiagMatrix: public ExtUserMatrix {
00582 public:
00585 bool one;
00586
00589 Pointer<UserVector<double> > diag;
00590
00595 DiagMatrix(Pointer<UserVector<double> > b_)
00596 : ExtUserMatrix(b_->dim()), diag(b_), one(*b_==1)
00597 { }
00598
00603 void MultV(UserVector<double>& y, const UserVector<double>& x) const {
00604 if (one) y=x;
00605 else diag->diagmult(y,x);
00606 }
00607
00608
00609 #ifdef FILIB_AVAILABLE
00610 void MultV(IntervalVector& y, const IntervalVector& x) const {
00611 if (one) y=x;
00612 else for (int i=0; i<dim(); i++) (y[i]=x(i))*=(*diag)(i);
00613 }
00614 #endif
00615
00616 #ifdef FILIB_AVAILABLE
00617 using IntervalCompliantMatrix::MultV;
00618 #else
00619 using UserMatrix::MultV;
00620 #endif
00621
00627 double yAx(const UserVector<double>& y,const UserVector<double>& x) const {
00628 if (one) return y*x;
00629 Pointer<UserVector<double> > z(y.getemptycopy());
00630 diag->diagmult(*z, x);
00631 return *z*y;
00632 }
00633
00634 #ifdef FILIB_AVAILABLE
00635 using IntervalCompliantMatrix::yAx;
00636 #else
00637 using UserMatrix::yAx;
00638 #endif
00639
00640 double operator()(int row, int col) const {
00641 if (row==col)
00642 if (one) return 1.;
00643 else return (*diag)(row);
00644 else return 0;
00645 }
00646
00651 void print(ostream &out) const {
00652 out << "DiagMatrix: dim: " << dim() << " Diag: " << *diag;
00653 }
00654 };
00655
00656 class BlockMatrix
00657 #ifdef FILIB_AVAILABLE
00658 : public IntervalCompliantMatrix
00659 #else
00660 : public UserMatrix
00661 #endif
00662 {
00663 public:
00666 vector<Pointer<UserMatrix> > A;
00667
00670 vector<ivector> block;
00671
00672 void set_dim() {
00673 dim_=0;
00674 for (int i=0; i<block.size(); i++) dim_+=block[i].size();
00675 }
00676
00680 BlockMatrix(int n=0)
00681 #ifdef FILIB_AVAILABLE
00682 : IntervalCompliantMatrix(n)
00683 #else
00684 : UserMatrix(n)
00685 #endif
00686 { };
00687
00688
00692 BlockMatrix(const vector<ivector>& block_)
00693 #ifdef FILIB_AVAILABLE
00694 : IntervalCompliantMatrix(),
00695 #else
00696 : UserMatrix(),
00697 #endif
00698 block(block_), A(block_.size())
00699 { set_dim(); }
00700
00705 BlockMatrix(const vector<ivector>& block_, const vector<Pointer<UserMatrix> >& A_)
00706 #ifdef FILIB_AVAILABLE
00707 : IntervalCompliantMatrix(),
00708 #else
00709 : UserMatrix(),
00710 #endif
00711 block(block_), A(A_)
00712 { set_dim(); }
00713
00717 BlockMatrix(Pointer<UserMatrix> A_)
00718 #ifdef FILIB_AVAILABLE
00719 : IntervalCompliantMatrix(A_ ? A_->dim() : 0),
00720 #else
00721 : UserMatrix(A_ ? A_->dim() : 0),
00722 #endif
00723 block(1), A(1)
00724 { block[0].resize(dim_);
00725 for (int i=0; i<dim_; i++) block[0][i]=i;
00726 A[0]=A_;
00727 }
00728
00732 BlockMatrix(const BlockMatrix& b)
00733 #ifdef FILIB_AVAILABLE
00734 : IntervalCompliantMatrix(b.dim()),
00735 #else
00736 : UserMatrix(b.dim()),
00737 #endif
00738 A(b.A), block(b.block)
00739 { }
00740
00745 void MultV(UserVector<double>& y, const UserVector<double>& x) const;
00746 using UserMatrix::MultV;
00747
00754 double yAx(const UserVector<double>& y,const UserVector<double>& x) const;
00755
00756 #ifdef FILIB_AVAILABLE
00757 void MultV(IntervalVector& y, const IntervalVector& x) const;
00758
00759 interval<double> yAx(const IntervalVector& y, const IntervalVector& x) const;
00760 #endif
00761
00766 Pointer<UserMatrix> operator[](int k) const {
00767 return A[k];
00768 }
00769
00774 void print(ostream &out) const {
00775 out << "BlockMatrix: dim: " << dim() << " Blocks: " << block.size() << endl;
00776 for (int i=0; i<block.size(); i++) {
00777 out << "block " << i << ": " << block[i];
00778 if (A[i]) out << *A[i];
00779 }
00780 }
00781
00782 };
00783
00784
00785
00788 class DenseMatrix : public ExtUserMatrix {
00789 private:
00809 void tred2(int nm, int n, double* a, double *d, double *e, double *z) const;
00810
00834 int imtql2(int nm, int n, double *d, double *e, double *z) const;
00835
00836 protected:
00839 TNT::Matrix<double> A;
00840
00841 public:
00845 bool allow_destroy;
00846
00853 DenseMatrix(bool allow_destroy_=false)
00854 : ExtUserMatrix(), A(), allow_destroy(allow_destroy_)
00855 { }
00856
00865 DenseMatrix(int n, double val=0., bool allow_destroy_=false)
00866 : ExtUserMatrix(n), A(n, n, val), allow_destroy(allow_destroy_)
00867 { }
00868
00878 DenseMatrix(const UserMatrix& A_, bool allow_destroy_=false);
00879
00884 DenseMatrix(const ExtUserMatrix& A_, bool allow_destroy_=false)
00885 : ExtUserMatrix(A_.dim()), A(A_.dim(), A_.dim(), 0.), allow_destroy(allow_destroy_)
00886 { for (int i=0; i<dim(); i++)
00887 for (int j=0; j<dim(); j++)
00888 A[i][j]=A_(i,j);
00889 }
00890
00898 DenseMatrix(const TNT::Matrix<double>& A_, bool allow_destroy_=false)
00899 : ExtUserMatrix(A_.num_rows()), A(A_), allow_destroy(allow_destroy_)
00900 { }
00901
00905 DenseMatrix(const DenseMatrix& D)
00906 : ExtUserMatrix(D.dim()), A(D.A), allow_destroy(D.allow_destroy)
00907 { }
00908
00912 operator double*() {
00913 return *(double**)A;
00914 }
00915
00921 DenseMatrix& operator=(const DenseMatrix& D) {
00922 if (this != &D) {
00923 dim_=D.dim();
00924 A=D.A;
00925 allow_destroy=D.allow_destroy;
00926 }
00927 return *this;
00928 }
00929
00934 DenseMatrix& operator=(const double scalar) {
00935 A=scalar;
00936 return *this;
00937 }
00938
00945 dvector operator[](int i) const {
00946 return dvector(A[i], dim());
00947 }
00948
00953
00954
00955
00956
00963 double& operator()(int row, int col) {
00964 return A[row][col];
00965 }
00966
00967 double operator()(int row, int col) const {
00968 return A[row][col];
00969 }
00970
00976 DenseMatrix operator+(const DenseMatrix &B) const {
00977 return DenseMatrix(A + B.A);
00978 }
00979
00985 DenseMatrix operator-(const DenseMatrix &B) const {
00986 return DenseMatrix(A - B.A);
00987 }
00988
00995 DenseMatrix operator*(const DenseMatrix &B) const {
00996 return DenseMatrix(A * B.A);
00997 }
00998
01005 void MultV(UserVector<double>& y_, const UserVector<double>& x_) const;
01006
01007 #ifdef FILIB_AVAILABLE
01008 void MultV(IntervalVector& y, const IntervalVector& x) const;
01009 #endif
01010
01017 void MultV(double* y_, const double* x_) const;
01018
01019 void set_random(const dvector& lambda);
01020
01030 int eig_ql(vector<dvector> &eig_vec, vector<double> &eig_val) const;
01031
01039 int eig_ql(dvector &eig_vec, double &eig_val) const;
01040
01045 void print(ostream &out) const {
01046 out << "DenseMatrix: " << A;
01047 }
01048 };
01049
01050 class SparseMatrix {
01051 protected:
01055 map<pair<int,int>, double> values;
01056
01060 int nz;
01064 double* val;
01068 int* row_ind;
01072 int* col_ptr;
01073
01076 int rows_, cols_;
01077
01078 public:
01083 SparseMatrix(int rows__, int cols__)
01084 : rows_(rows__), cols_(cols__), val(0), row_ind(0), col_ptr(0), nz(-1)
01085 { }
01086
01090 SparseMatrix(const SparseMatrix& A_)
01091 : rows_(A_.rows_), cols_(A_.cols_), nz(A_.nz),
01092 val(A_.nz>=0 ? new double[A_.nz] : NULL),
01093 row_ind(A_.nz>=0 ? new int[A_.nz] : NULL),
01094 col_ptr(A_.nz>=0 ? new int[A_.cols_+1] : NULL)
01095 { if (nz>0) {
01096 memcpy(val, A_.val, nz * sizeof(double));
01097 memcpy(row_ind, A_.row_ind, nz * sizeof(double));
01098 memcpy(col_ptr, A_.col_ptr, (cols_+1) * sizeof(double));
01099 }
01100 }
01101
01106 SparseMatrix(const UserMatrix& A_, bool no_finish=false);
01107
01112 SparseMatrix(const ExtUserMatrix& A_, bool no_finish=false)
01113 : rows_(A_.dim()), cols_(A_.dim()), val(NULL), row_ind(NULL), col_ptr(NULL), nz(-1)
01114 { for (int i=0; i<rows_; i++)
01115 for (int j=0; j<cols_; j++)
01116 AddElement(i,j,A_(i,j));
01117 if (!no_finish) finish();
01118 }
01119
01123 virtual ~SparseMatrix() {
01124 if (val) delete[] val;
01125 if (row_ind) delete[] row_ind;
01126 if (col_ptr) delete[] col_ptr;
01127 }
01128
01132 void resize(int rows__, int cols__);
01133
01134 int rows() const { return rows_; }
01135 int cols() const { return cols_; }
01136
01137 const int* GetRowInd() const { return row_ind; }
01138 const int* GetColPtr() const { return col_ptr; }
01139 const double* GetVal() const { return val; }
01140 double* GetVal() { return val; }
01141
01142 virtual double operator()(int row, int col) const {
01143 #ifndef NO_SPARSEMATRIX_ASSERTS
01144 assert(val!=NULL);
01145 #endif
01146 for (int i=col_ptr[col]; i<col_ptr[col+1]; i++)
01147 if (row_ind[i]==row) return val[i];
01148 return 0.;
01149 }
01150
01151 int nonzeros() const {
01152 if (nz>=0) return nz;
01153 else return values.size();
01154 }
01155
01163 void AddElement(int row, int col, double v, bool check_zero=true) {
01164 #ifndef NO_SPARSEMATRIX_ASSERTS
01165 assert(val==NULL);
01166 assert(0<=row && row<rows_);
01167 assert(0<=col && col<cols_);
01168 #endif
01169 if (check_zero && fabs(v)<rtol) return;
01170 values.insert(pair<pair<int,int>, double>(pair<int,int>(col, row), v));
01171 }
01172
01173 void AddToElement(int row, int col, double v, bool check_zero=true) {
01174 #ifndef NO_SPARSEMATRIX_ASSERTS
01175 assert(val==NULL);
01176 assert(0<=row && row<rows_);
01177 assert(0<=col && col<cols_);
01178 #endif
01179 if (check_zero && fabs(v)<rtol) return;
01180 map<pair<int,int>, double >::iterator it(values.find(pair<int,int>(col, row)));
01181
01182 if (it==values.end()) values.insert(pair<pair<int,int>, double>(pair<int,int>(col, row), v));
01183 else it->second+=v;
01184 }
01185
01192 SparseMatrix& operator+=(const UserMatrix& A_);
01193
01199 SparseMatrix& operator+=(const ExtUserMatrix& A_) {
01200 #ifndef NO_SPARSEMATRIX_ASSERTS
01201 assert(rows_==A_.dim());
01202 assert(cols_==A_.dim());
01203 #endif
01204 for (int row=0; row<rows_; row++)
01205 for (int col=0; col<cols_; col++)
01206 AddToElement(row, col, A_(row,col));
01207 return *this;
01208 }
01209
01210 SparseMatrix& operator=(const double v);
01211
01212 SparseMatrix& operator*=(const double v);
01213
01214 void set_block(const SparseMatrix& A, const ivector& indices);
01215
01220 void finish();
01221
01226 virtual void MultV(double* y, const double* x) const {
01227 #ifndef NO_SPARSEMATRIX_ASSERTS
01228 assert(val);
01229 #endif
01230 memset(y, 0, rows_ * sizeof(double));
01231 int j=0;
01232 for (int col=0; col<cols_; col++, x++)
01233 for (; j<col_ptr[col+1]; j++)
01234 y[row_ind[j]] += *x * val[j];
01235 }
01236
01241 virtual void MultV(dvector& y, const dvector& x) const {
01242 #ifndef NO_SPARSEMATRIX_ASSERTS
01243 assert(y.dim()==rows_);
01244 assert(x.dim()==cols_);
01245 #endif
01246 MultV((Pointer<double>)y, (Pointer<double>)x);
01247 }
01248
01253 virtual void MultV(UserVector<double>& y, const UserVector<double>& x) const {
01254 #ifndef NO_SPARSEMATRIX_ASSERTS
01255 assert(val);
01256 assert(y.dim()==rows_);
01257 assert(x.dim()==cols_);
01258 #endif
01259 y=0;
01260 int j=0;
01261 double xi;
01262 for (int col=0; col<cols_; col++)
01263 if (xi=x(col))
01264 for (j=col_ptr[col]; j<col_ptr[col+1]; j++)
01265 y[row_ind[j]] += xi * val[j];
01266 }
01267
01268 #ifdef FILIB_AVAILABLE
01269 virtual void MultV(IntervalVector& y, const IntervalVector& x) const {
01270 #ifndef NO_SPARSEMATRIX_ASSERTS
01271 assert(val);
01272 assert(y.dim()==rows_);
01273 assert(x.dim()==cols_);
01274 #endif
01275 int j=0;
01276 interval<double> zero(0.);
01277 y=zero;
01278 for (int col=0; col<cols_; col++)
01279 if (x(col)!=zero)
01280 for (j=col_ptr[col]; j<col_ptr[col+1]; j++)
01281 y[row_ind[j]] += val[j] * x(col);
01282 }
01283
01284 virtual interval<double> yAx(const IntervalVector& y, const IntervalVector& x) const {
01285 #ifndef NO_SPARSEMATRIX_ASSERTS
01286 assert(val);
01287 assert(y.dim()==rows_);
01288 assert(x.dim()==cols_);
01289 #endif
01290 interval<double> ret(0.);
01291 interval<double> zero(0.);
01292
01293 int j=0;
01294 for (int col=0; col<cols_; col++)
01295 if (x(col)!=zero)
01296 for (j=col_ptr[col]; j<col_ptr[col+1]; j++)
01297 ret+=y(row_ind[j]) * x(col) * val[j];
01298
01299 return ret;
01300 }
01301
01302 virtual interval<double> xAx(const IntervalVector& x) const {
01303 #ifndef NO_SPARSEMATRIX_ASSERTS
01304 assert(val);
01305 assert(rows_==cols_);
01306 assert(x.dim()==cols_);
01307 #endif
01308 interval<double> ret(0.);
01309 interval<double> zero(0.);
01310
01311 int j=0;
01312 for (int col=0; col<cols_; col++)
01313 if (x(col)!=zero) {
01314 for (j=col_ptr[col]; j<col_ptr[col+1]; j++)
01315 if (row_ind[j]==col) ret+=val[j]*sqr(x(col));
01316 else ret+=x(row_ind[j]) * x(col) * val[j];
01317 }
01318 return ret;
01319 }
01320
01321 virtual interval<double> xAx_2bx(const IntervalVector& x, const UserVector<double>& b) const;
01322 #endif
01323
01328 virtual void MultV(SparseVector<double>& y, const SparseVector<double>& x) const {
01329 #ifndef NO_SPARSEMATRIX_ASSERTS
01330 assert(val);
01331 assert(y.dim()==rows_);
01332 assert(x.dim()==cols_);
01333 #endif
01334 double* y0=new double[cols_]; memset(y0, 0, cols_*sizeof(double));
01335 int j;
01336 SparseVector<double>::VectorElement* v=x.head->next;
01337 while (v) {
01338 for (j=col_ptr[v->index]; j<col_ptr[v->index+1]; j++)
01339 y0[row_ind[j]] += v->value * val[j];
01340 v=v->next;
01341 }
01342 y.set(y0);
01343 delete y0;
01344 }
01345
01346 virtual double yAx(const UserVector<double>& y, const UserVector<double>& x) const {
01347 #ifndef NO_SPARSEMATRIX_ASSERTS
01348 assert(val);
01349 assert(y.dim()==rows_);
01350 assert(x.dim()==cols_);
01351 #endif
01352 double ret=0;
01353
01354 int j=0;
01355 double xi;
01356 for (int col=0; col<cols_; col++)
01357 if (xi=x(col))
01358 for (j=col_ptr[col]; j<col_ptr[col+1]; j++)
01359 ret+=y(row_ind[j]) * xi * val[j];
01360
01361 return ret;
01362 }
01363
01364 virtual double yAx(const dvector& y, const dvector& x) const {
01365 #ifndef NO_SPARSEMATRIX_ASSERTS
01366 assert(y.dim()==rows_);
01367 assert(x.dim()==cols_);
01368 assert(val);
01369 #endif
01370 double ret=0;
01371 double* y0=(Pointer<double>)y;
01372 double* x0=(Pointer<double>)x;
01373
01374 int j=0;
01375 for (int col=0; col<cols_; col++, x0++)
01376 for (; j<col_ptr[col+1]; j++)
01377 ret+=y[row_ind[j]] * *x0 * val[j];
01378
01379 return ret;
01380 }
01381
01382 virtual double xAx(const UserVector<double>& x) const { return yAx(x,x); }
01383
01389 virtual void AddMult(UserVector<double>& y, const UserVector<double>& x, const double alpha) const {
01390 #ifndef NO_SPARSEMATRIX_ASSERTS
01391 assert(y.dim()==rows_);
01392 assert(x.dim()==cols_);
01393 assert(val);
01394 #endif
01395 int j=0;
01396 double xi;
01397 for (int col=0; col<cols_; col++)
01398 if (xi=alpha*x(col))
01399 for (j=col_ptr[col]; j<col_ptr[col+1]; j++)
01400 y[row_ind[j]] += xi * val[j];
01401 };
01402
01407 void plot(char* filename) const;
01408
01412 virtual void print(ostream& out) const;
01413
01414 };
01415
01416 class SparseMatrix2 : public ExtUserMatrix, public SparseMatrix {
01417 public:
01421 SparseMatrix2(int n)
01422 : SparseMatrix(n, n), ExtUserMatrix(n)
01423 { }
01424
01428 SparseMatrix2(const SparseMatrix& A_)
01429 : SparseMatrix(A_), ExtUserMatrix(A_.rows())
01430 { assert(A_.rows()==A_.cols());
01431 }
01432
01437 SparseMatrix2(const UserMatrix& A_, bool no_finish=false)
01438 : SparseMatrix(A_, no_finish), ExtUserMatrix(A_.dim())
01439 { }
01440
01445 SparseMatrix2(const ExtUserMatrix& A_, bool no_finish=false)
01446 : SparseMatrix(A_, no_finish), ExtUserMatrix(A_.dim())
01447 { }
01448
01449 double operator()(int row, int col) const { return SparseMatrix::operator()(row, col); }
01450
01451 SparseMatrix2& operator+=(const UserMatrix& A_) { SparseMatrix::operator+=(A_); return *this; }
01452
01453 SparseMatrix2& operator+=(const ExtUserMatrix& A_) { SparseMatrix::operator+=(A_); return *this; }
01454
01455 SparseMatrix2& operator=(const double v) { SparseMatrix::operator=(v); return *this; }
01456
01457 SparseMatrix2& operator*=(const double v) { SparseMatrix::operator*=(v); return *this; }
01458
01459 void set_block(const SparseMatrix& A, const ivector& indices) { SparseMatrix::set_block(A, indices); }
01460
01461 void MultV(double* y, const double* x) const { SparseMatrix::MultV(y,x); }
01462
01463 void MultV(dvector& y, const dvector& x) const { SparseMatrix::MultV(y,x); }
01464
01465 void MultV(UserVector<double>& y, const UserVector<double>& x) const { SparseMatrix::MultV(y,x); }
01466
01467 #ifdef FILIB_AVAILABLE
01468 void MultV(IntervalVector& y, const IntervalVector& x) const { SparseMatrix::MultV(y,x); }
01469 #endif
01470
01471 void MultV(SparseVector<double>& y, const SparseVector<double>& x) const { SparseMatrix::MultV(y,x); }
01472
01473 double yAx(const UserVector<double>& y, const UserVector<double>& x) const { return SparseMatrix::yAx(y,x); }
01474
01475 double yAx(const dvector& y, const dvector& x) const { return SparseMatrix::yAx(y,x); }
01476
01477 #ifdef FILIB_AVAILABLE
01478 using IntervalCompliantMatrix::yAx;
01479 #else
01480 using UserMatrix::yAx;
01481 #endif
01482
01483 double xAx(const UserVector<double>& x) const { return SparseMatrix::xAx(x); }
01484
01485 #ifdef FILIB_AVAILABLE
01486 interval<double> xAx(const IntervalVector& x) const { return SparseMatrix::xAx(x); }
01487
01488 interval<double> xAx_2bx(const IntervalVector& x, const UserVector<double>& b) const { return SparseMatrix::xAx_2bx(x, b); }
01489 #endif
01490
01491 virtual void AddMult(UserVector<double>& y, const UserVector<double>& x, const double alpha) const { SparseMatrix::AddMult(y,x,alpha); }
01492
01493 void make_symmetric();
01494
01500 void set_random(int num_el, double max=1.);
01501
01509 static SparseMatrix2* random(int n, int num_el, double max=1.) {
01510 SparseMatrix2* A=new SparseMatrix2(n);
01511 A->set_random(num_el, max);
01512 return A;
01513 }
01514
01515 virtual void print(ostream& out) const { SparseMatrix::print(out); }
01516 };
01517
01518 #endif // USERMATRIX_H