solve.f90 Source File


Source Code

submodule(specialmatrices_tridiagonal) tridiagonal_linear_solver
   use stdlib_optval, only: optval
   use stdlib_linalg_lapack, only: gttrf, gttrs, gtrfs
   use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
                                  LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR, LINALG_SUCCESS
   implicit none(type, external)

   character(*), parameter :: this = "tridiagonal_linear_solver"
contains

   ! Process GTTRF
   elemental subroutine handle_gttrf_info(n, info, err)
      integer(ilp), intent(in) :: n, info
      type(linalg_state_type), intent(inout) :: err

      select case (info)
      case (0)
         ! Success.
         err%state = LINALG_SUCCESS
      case (-1)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid problem size n=", n)
      case (-2)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid size for dl.")
      case (-3)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid size for d.")
      case (-4)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid size for du.")
      case (-5)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid size for du2.")
      case (-6)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid size for ipiv.")
      case (1:)
         err = linalg_state_type(this, LINALG_ERROR, "Singular matrix.")
      case default
         err = linalg_state_type(this, LINALG_INTERNAL_ERROR, "Unknown error returned by gttrf")
      end select
   end subroutine handle_gttrf_info

   ! Process GTTRS
   elemental subroutine handle_gttrs_info(trans, n, nrhs, ldb, info, err)
      character, intent(in) :: trans
      integer(ilp), intent(in) :: n, nrhs, ldb, info
      type(linalg_state_type), intent(inout) :: err

      select case (info)
      case (0)
         ! Success.
         err%state = LINALG_SUCCESS
      case (-1)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for trans", trans)
      case (-2)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid problem size n=", n)
      case (-3)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid number of rhs nrhs=", nrhs)
      case (-4)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid dimensions for dl.")
      case (-5)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid dimensions for d.")
      case (-6)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid dimensions for du2.")
      case (-7)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid dimensions for ipiv.")
      case (-8)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid dimensions for b.")
      case (-9)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for ldb=", ldb)
      case default
         err = linalg_state_type(this, LINALG_INTERNAL_ERROR, "Unknown error returned by gttrs")
      end select
   end subroutine handle_gttrs_info

   ! Process GTRFS
   elemental subroutine handle_gtrfs_info(trans, n, nrhs, ldb, ldx, info, err)
      character, intent(in) :: trans
      integer(ilp), intent(in) :: n, nrhs, ldb, ldx, info
      type(linalg_state_type), intent(inout) :: err

      select case (info)
      case (0)
         ! Success.
         err%state = LINALG_SUCCESS
      case (-1)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for trans=", trans)
      case (-2)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for n=", n)
      case (-3)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for nrhs=", nrhs)
      case (-4)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for dl.")
      case (-5)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for d.")
      case (-6)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for du.")
      case (-7)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for dlf.")
      case (-8)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for df.")
      case (-9)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for duf.")
      case (-10)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for du2.")
      case (-11)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for ipiv.")
      case (-12)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for b.")
      case (-13)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for ldb=", ldb)
      case (-14)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for x.")
      case (-15)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for ldx=", ldx)
      case (-16)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for ferr.")
      case (-17)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for berr.")
      case (-18)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for work.")
      case (-19)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for iwork.")
      case default
         err = linalg_state_type(this, LINALG_INTERNAL_ERROR, "Unknown error returned by gtrfs")
      end select
   end subroutine handle_gtrfs_info

   function tridiagonal_solver(A, b, refine) result(x)
      type(Tridiagonal), intent(in) :: A
      !! Coefficient matrix.
      real(dp), intent(in) :: b(:, :)
      !! Right-hand side vectors.
      logical(lk), intent(in) :: refine
      !! Iterative refinement of the solution?
      real(dp), allocatable :: x(:, :)
      !! Solution vectors.

      ! General LAPACK variables.
      integer(ilp) :: n, nrhs, info
      ! LAPACK variables for LU decomposition.
      real(dp), allocatable :: dl(:), d(:), du(:), du2(:)
      integer(ilp), allocatable :: ipiv(:)
      ! LAPACK variables for iterative refinement.
      real(dp), allocatable :: ferr(:), berr(:), work(:)
      integer(ilp), allocatable :: iwork(:)

      ! Error handler.
      type(linalg_state_type) :: err

      ! Initialize data.
      n = A%n; nrhs = size(b, 2); x = b

      !------------------------------------
      !-----     LU factorization     -----
      !------------------------------------

      ! ----- Allocations -----
      allocate (du2(n - 2), ipiv(n))
      dl = A%dl; d = A%dv; du = A%du; 
      ! ----- LU factorization -----
      call gttrf(n, dl, d, du, du2, ipiv, info)
      call handle_gttrf_info(n, info, err)

      !-------------------------------------
      !-----     Tridiagonal solve     -----
      !-------------------------------------

      ! ----- Solve the system -----
      call gttrs("N", n, nrhs, dl, d, du, du2, ipiv, x, n, info)
      call handle_gttrs_info("N", n, nrhs, n, info, err)

      !----------------------------------------
      !-----     Iterative refinement     -----
      !----------------------------------------

      if (refine) then
         ! ----- Allocate arrays -----
         allocate (ferr(nrhs), berr(nrhs), work(3*n), iwork(n))
         ! ----- Refinement step -----
         call gtrfs("N", n, nrhs, A%dl, A%dv, A%du, dl, d, du, du2, ipiv, b, &
                    n, x, n, ferr, berr, work, iwork, info)
         call handle_gtrfs_info("N", n, nrhs, n, n, info, err)
      end if
   end function tridiagonal_solver

   module procedure solve_single_rhs
   ! Local variables.
   logical(lk) :: refine_
   real(dp), pointer :: xmat(:, :), bmat(:, :)
   refine_ = optval(refine, .false.)
   x = b; xmat(1:A%n, 1:1) => x; bmat(1:A%n, 1:1) => b
   xmat = tridiagonal_solver(A, bmat, refine_)
   end procedure

   module procedure solve_multi_rhs
   ! Local variables.
   logical(lk) :: refine_
   refine_ = optval(refine, .false.)
   x = tridiagonal_solver(A, b, refine_)
   end procedure
end submodule