Skip to content

Commit e7f2a2a

Browse files
committed
Merge pull request pytorch#320 from yozw/trtrs
Exposing the LAPACK function trtrs, which solves triangular systems of linear equations.
2 parents 1fc2038 + 1b0f782 commit e7f2a2a

File tree

4 files changed

+113
-0
lines changed

4 files changed

+113
-0
lines changed

generic/THLapack.c

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
TH_EXTERNC void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
77
TH_EXTERNC void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);
8+
/* External LAPACK triangular solver, double precision. Every argument,
   including the scalar ones, is passed by pointer; the status code is
   written to *info. Called by THLapack_(trtrs) when TH_REAL_IS_DOUBLE is defined. */
TH_EXTERNC void dtrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
9+
/* External LAPACK triangular solver, single precision. Every argument,
   including the scalar ones, is passed by pointer; the status code is
   written to *info. Called by THLapack_(trtrs) when TH_REAL_IS_DOUBLE is not defined. */
TH_EXTERNC void strtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
810
TH_EXTERNC void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info);
911
TH_EXTERNC void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info);
1012
TH_EXTERNC void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info);
@@ -29,6 +31,7 @@ TH_EXTERNC void sorgqr_(int *m, int *n, int *k, float *a, int *lda, float *tau,
2931
TH_EXTERNC void dorgqr_(int *m, int *n, int *k, double *a, int *lda, double *tau, double *work, int *lwork, int *info);
3032

3133

34+
/* Compute the solution to a real system of linear equations A * X = B */
3235
void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int* info)
3336
{
3437
#ifdef USE_LAPACK
@@ -43,6 +46,23 @@ void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int
4346
return;
4447
}
4548

49+
/* Solve a triangular system of the form A * X = B or A^T * X = B */
50+
void THLapack_(trtrs)(char uplo, char trans, char diag, int n, int nrhs, real *a, int lda, real *b, int ldb, int* info)
51+
{
52+
#ifdef USE_LAPACK
53+
#if defined(TH_REAL_IS_DOUBLE)
54+
dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
55+
#else
56+
strtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
57+
#endif
58+
#else
59+
THError("trtrs : Lapack library not found in compile time\n");
60+
#endif
61+
return;
62+
}
63+
64+
/* Solve overdetermined or underdetermined real linear systems involving an
65+
M-by-N matrix A, or its transpose, using a QR or LQ factorization of A */
4666
void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real *b, int ldb, real *work, int lwork, int *info)
4767
{
4868
#ifdef USE_LAPACK
@@ -56,6 +76,8 @@ void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real
5676
#endif
5777
}
5878

79+
/* Compute all eigenvalues and, optionally, eigenvectors of a real symmetric
80+
matrix A */
5981
void THLapack_(syev)(char jobz, char uplo, int n, real *a, int lda, real *w, real *work, int lwork, int *info)
6082
{
6183
#ifdef USE_LAPACK
@@ -69,6 +91,8 @@ void THLapack_(syev)(char jobz, char uplo, int n, real *a, int lda, real *w, rea
6991
#endif
7092
}
7193

94+
/* Compute for an N-by-N real nonsymmetric matrix A, the eigenvalues and,
95+
optionally, the left and/or right eigenvectors */
7296
void THLapack_(geev)(char jobvl, char jobvr, int n, real *a, int lda, real *wr, real *wi, real* vl, int ldvl, real *vr, int ldvr, real *work, int lwork, int *info)
7397
{
7498
#ifdef USE_LAPACK
@@ -82,6 +106,8 @@ void THLapack_(geev)(char jobvl, char jobvr, int n, real *a, int lda, real *wr,
82106
#endif
83107
}
84108

109+
/* Compute the singular value decomposition (SVD) of a real M-by-N matrix A,
110+
optionally computing the left and/or right singular vectors */
85111
void THLapack_(gesvd)(char jobu, char jobvt, int m, int n, real *a, int lda, real *s, real *u, int ldu, real *vt, int ldvt, real *work, int lwork, int *info)
86112
{
87113
#ifdef USE_LAPACK

generic/THLapack.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
/* AX=B */
66
TH_API void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int* info);
7+
/* Solve a triangular system of the form A * X = B or A^T * X = B */
8+
TH_API void THLapack_(trtrs)(char uplo, char trans, char diag, int n, int nrhs, real *a, int lda, real *b, int ldb, int* info);
79
/* ||AX-B|| */
810
TH_API void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real *b, int ldb, real *work, int lwork, int *info);
911
/* Eigenvals */

generic/THTensorLapack.c

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,90 @@ void THTensor_(gesv)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a)
164164
THIntTensor_free(ipiv);
165165
}
166166

167+
/* Solve a triangular system with LAPACK TRTRS.
 *
 * Arguments
 *   rb_, ra_ : result tensors. They receive the contents of the working
 *              copies of B and A after the LAPACK call (per the LAPACK TRTRS
 *              documentation the solution X overwrites B).
 *   b, a     : input tensors. A NULL input, or an input that is the same
 *              pointer as its result tensor, means "use the result tensor as
 *              the input"; in that case a temporary working tensor is created
 *              and, if lapackClone reports that it made a copy, the temporary
 *              is copied back into the result tensor afterwards. Otherwise
 *              the result tensor itself receives a forced lapackClone of the
 *              input and is used directly as LAPACK workspace.
 *   uplo, trans, diag : only the first character of each string is used; it
 *              is forwarded to THLapack_(trtrs) unchanged.
 *
 * Errors (raised through THArgCheck / THError)
 *   - A or B is not 2 dimensional, A is not square, or the first dimension
 *     of B differs from that of A;
 *   - LAPACK returns info < 0 (illegal argument) or info > 0 (zero diagonal
 *     element, i.e. singular A).
 *
 * NOTE(review): THArgCheck / THError are assumed not to return on failure.
 * For that reason the rank of both inputs is validated before anything is
 * allocated, and temporaries are released before a shape error on the
 * working copies is reported, so a failed check does not leak them.
 */
void THTensor_(trtrs)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a,
                      const char *uplo, const char *trans, const char *diag)
{
  int n, nrhs, lda, ldb;
  int info = 0;   /* defined value even if the backend never writes it */

  /* Tensors that actually hold the input data (NULL means "use the result"). */
  THTensor *a_in = (a == NULL) ? ra_ : a;
  THTensor *b_in = (b == NULL) ? rb_ : b;

  THTensor *ra__; /* working version of A matrix to be passed into lapack TRTRS */
  THTensor *rb__; /* working version of B matrix to be passed into lapack TRTRS */

  int clonea;     /* set to 1 if ra__ should be copied into ra_ at return */
  int cloneb;     /* set to 1 if rb__ should be copied into rb_ at return */
  int destroya;   /* set to 1 if ra__ needs to be destroyed at return */
  int destroyb;   /* set to 1 if rb__ needs to be destroyed at return */

  int a_is_2d, b_is_2d, a_is_square, sizes_match;

  /* Validate the ranks of both inputs up front, before any temporary exists. */
  THArgCheck(a_in->nDimension == 2, 1, "A should be 2 dimensional");
  THArgCheck(b_in->nDimension == 2, 2, "B should be 2 dimensional");

  if (a == NULL || ra_ == a) /* possibly destroy the inputs */
  {
    ra__ = THTensor_(new)();
    clonea = THTensor_(lapackClone)(ra__, ra_, 0);
    destroya = 1;
  }
  else /* we want to clone and use ra_ as computational space */
  {
    clonea = THTensor_(lapackClone)(ra_, a, 1);
    ra__ = ra_;
    destroya = 0;
  }

  if (b == NULL || rb_ == b) /* possibly destroy the inputs */
  {
    rb__ = THTensor_(new)();
    cloneb = THTensor_(lapackClone)(rb__, rb_, 0);
    destroyb = 1;
  }
  else /* we want to clone and use rb_ as computational space */
  {
    cloneb = THTensor_(lapackClone)(rb_, b, 1);
    rb__ = rb_;
    destroyb = 0;
  }

  /* Shape checks on the working copies. The size comparisons are only
     evaluated when both tensors are 2 dimensional. */
  a_is_2d = (ra__->nDimension == 2);
  b_is_2d = (rb__->nDimension == 2);
  a_is_square = a_is_2d && (ra__->size[0] == ra__->size[1]);
  sizes_match = a_is_2d && b_is_2d && (rb__->size[0] == ra__->size[0]);

  if (!(a_is_2d && b_is_2d && a_is_square && sizes_match))
  {
    /* Release temporaries first: the checks below are not expected to return. */
    if (destroya)
    {
      THTensor_(free)(ra__);
    }
    if (destroyb)
    {
      THTensor_(free)(rb__);
    }
    THArgCheck(a_is_2d, 1, "A should be 2 dimensional");
    THArgCheck(b_is_2d, 2, "b should be 2 dimensional");
    THArgCheck(a_is_square, 1, "A should be square");
    THArgCheck(sizes_match, 2, "A,b size incompatible");
    return; /* never touch the freed temporaries, even if the handler returns */
  }

  n    = (int)ra__->size[0];
  nrhs = (int)rb__->size[1];
  lda  = n;
  ldb  = n;

  THLapack_(trtrs)(uplo[0], trans[0], diag[0], n, nrhs,
                   THTensor_(data)(ra__), lda,
                   THTensor_(data)(rb__), ldb, &info);

  /* clean up: copy temporaries back into the result tensors where
     lapackClone reported a copy, then release them */
  if (destroya)
  {
    if (clonea)
    {
      THTensor_(copy)(ra_, ra__);
    }
    THTensor_(free)(ra__);
  }
  if (destroyb)
  {
    if (cloneb)
    {
      THTensor_(copy)(rb_, rb__);
    }
    THTensor_(free)(rb__);
  }

  /* Report LAPACK failures only after the temporaries have been released. */
  if (info < 0)
  {
    THError("Lapack trtrs : Argument %d : illegal value", -info);
  }
  else if (info > 0)
  {
    THError("Lapack trtrs : A(%d,%d) is zero, singular A.", info,info);
  }
}
250+
167251
void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a)
168252
{
169253
// Note that a = NULL is interpreted as a = ra_, and b = NULL as b = rb_.

generic/THTensorLapack.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#else
44

55
TH_API void THTensor_(gesv)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_);
6+
/* Triangular solve via LAPACK trtrs. rb_ / ra_ are the result tensors; when
   b_ / a_ is NULL (or the same pointer as its result tensor) the result
   tensor is also used as the input. Only the first character of uplo, trans
   and diag is used; it is forwarded to the LAPACK routine. */
TH_API void THTensor_(trtrs)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_, const char *uplo, const char *trans, const char *diag);
67
TH_API void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_);
78
TH_API void THTensor_(syev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *jobz, const char *uplo);
89
TH_API void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *jobvr);

0 commit comments

Comments
 (0)