solve.f90 Source File


Source Code

submodule(specialmatrices_symtridiagonal) symtridiagonal_linear_solver
   use stdlib_optval, only: optval
   use stdlib_linalg_lapack, only: gttrf, gttrs, gtrfs
   use stdlib_linalg_lapack, only: pttrf, pttrs, ptrfs
   use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
                                  LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR, LINALG_SUCCESS
   implicit none(type, external)

   character(*), parameter :: this = "symtridiagonal_linear_solver"
contains

   module procedure solve_single_rhs
   ! Single right-hand side: view b as an (A%n x 1) matrix, delegate to the
   ! matrix right-hand-side kernels, then copy the single solution column
   ! back into the first A%n entries of x.
   ! NOTE(review): assumes size(b) >= A%n, as the original remapping did.
   logical(lk) :: with_refinement
   real(dp), allocatable :: sol(:, :)

   with_refinement = optval(refine, .false.)
   x = b
   if (.not. A%isposdef) then
      sol = symtridiagonal_solver(A, reshape(b, [A%n, 1_ilp]), with_refinement)
   else
      sol = posdef_symtridiagonal_solver(A, reshape(b, [A%n, 1_ilp]), with_refinement)
   end if
   x(1:A%n) = sol(:, 1)
   end procedure

   module procedure solve_multi_rhs
   ! Multiple right-hand sides: choose the Cholesky-like (pttrf/pttrs) kernel
   ! when A is flagged positive definite, the pivoted LU kernel otherwise.
   logical(lk) :: with_refinement

   with_refinement = optval(refine, .false.)
   select case (A%isposdef)
   case (.true.)
      x = posdef_symtridiagonal_solver(A, b, with_refinement)
   case default
      x = symtridiagonal_solver(A, b, with_refinement)
   end select
   end procedure

   !---------------------------------------------------
   !-----     Generic (Sym)Tridiagonal Solver     -----
   !---------------------------------------------------

   ! Process GTTRF
   elemental subroutine handle_gttrf_info(n, info, err)
      !! Translate the `info` flag returned by `gttrf` into a `linalg_state_type`.
      !! Zero is success, a positive value is reported as a singular matrix,
      !! and negative values identify the invalid argument.
      integer(ilp), intent(in) :: n, info
      type(linalg_state_type), intent(inout) :: err

      ! Success only refreshes the state flag.
      if (info == 0) then
         err%state = LINALG_SUCCESS
         return
      end if

      ! Positive flag: the factorization hit a zero pivot.
      if (info > 0) then
         err = linalg_state_type(this, LINALG_ERROR, "Singular matrix.")
         return
      end if

      ! Negative flag: report the offending argument (last to first).
      select case (info)
      case (-6)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid size for ipiv.")
      case (-5)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid size for du2.")
      case (-4)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid size for du.")
      case (-3)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid size for d.")
      case (-2)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid size for dl.")
      case (-1)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid problem size n=", n)
      case default
         err = linalg_state_type(this, LINALG_INTERNAL_ERROR, "Unknown error returned by gttrf")
      end select
   end subroutine handle_gttrf_info

   ! Process GTTRS
   elemental subroutine handle_gttrs_info(trans, n, nrhs, ldb, info, err)
      !! Translate the `info` flag returned by `gttrs` into a `linalg_state_type`.
      !!
      !! LAPACK reports `info = -i` when the i-th argument of
      !! `gttrs(trans, n, nrhs, dl, d, du, du2, ipiv, b, ldb, info)` is illegal,
      !! so the codes below follow that 10-argument order (note `du` at -6 and
      !! `ldb` at -10; the routine itself validates trans, n, nrhs and ldb).
      character, intent(in) :: trans
      integer(ilp), intent(in) :: n, nrhs, ldb, info
      type(linalg_state_type), intent(inout) :: err

      select case (info)
      case (0)
         ! Success.
         err%state = LINALG_SUCCESS
      case (-1)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for trans", trans)
      case (-2)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid problem size n=", n)
      case (-3)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid number of rhs nrhs=", nrhs)
      case (-4)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid dimensions for dl.")
      case (-5)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid dimensions for d.")
      case (-6)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid dimensions for du.")
      case (-7)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid dimensions for du2.")
      case (-8)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid dimensions for ipiv.")
      case (-9)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid dimensions for b.")
      case (-10)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for ldb=", ldb)
      case default
         err = linalg_state_type(this, LINALG_INTERNAL_ERROR, "Unknown error returned by gttrs")
      end select
   end subroutine handle_gttrs_info

   ! Process GTRFS
   elemental subroutine handle_gtrfs_info(trans, n, nrhs, ldb, ldx, info, err)
      !! Translate the `info` flag returned by `gtrfs` into a `linalg_state_type`.
      !! Zero is success; negative values identify the invalid argument; any
      !! other value is reported as an internal error.
      character, intent(in) :: trans
      integer(ilp), intent(in) :: n, nrhs, ldb, ldx, info
      type(linalg_state_type), intent(inout) :: err

      ! Success only refreshes the state flag.
      if (info == 0) then
         err%state = LINALG_SUCCESS
         return
      end if

      ! Failure: report the offending argument (last to first).
      select case (info)
      case (-19)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for iwork.")
      case (-18)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for work.")
      case (-17)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for berr.")
      case (-16)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for ferr.")
      case (-15)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for ldx=", ldx)
      case (-14)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for x.")
      case (-13)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for ldb=", ldb)
      case (-12)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for b.")
      case (-11)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for ipiv.")
      case (-10)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for du2.")
      case (-9)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for duf.")
      case (-8)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for df.")
      case (-7)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for dlf.")
      case (-6)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for du.")
      case (-5)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for d.")
      case (-4)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for dl.")
      case (-3)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for nrhs=", nrhs)
      case (-2)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for n=", n)
      case (-1)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for trans=", trans)
      case default
         err = linalg_state_type(this, LINALG_INTERNAL_ERROR, "Unknown error returned by gtrfs")
      end select
   end subroutine handle_gtrfs_info

   function symtridiagonal_solver(A, b, refine) result(x)
      !! Solve A x = b for a symmetric tridiagonal matrix that is not known to be
      !! positive definite, using the general tridiagonal LU factorization with
      !! partial pivoting (`gttrf` + `gttrs`), optionally followed by one call to
      !! the iterative-refinement routine `gtrfs`.
      !!
      !! Every LAPACK status is forwarded to `linalg_error_handling`. Since this
      !! function returns no error argument, a failure (e.g. an exactly singular
      !! matrix) halts instead of silently returning Inf/NaN.
      type(SymTridiagonal), intent(in) :: A
      !! Coefficient matrix.
      real(dp), intent(in) :: b(:, :)
      !! Right-hand side vectors.
      logical(lk), intent(in) :: refine
      !! Iterative refinement of the solution?
      real(dp), allocatable :: x(:, :)
      !! Solution vectors.

      ! General LAPACK variables.
      integer(ilp) :: n, nrhs, ld, info
      ! LAPACK variables for LU decomposition.
      real(dp), allocatable :: dl(:), d(:), du(:), du2(:)
      integer(ilp), allocatable :: ipiv(:)
      ! LAPACK variables for iterative refinement.
      real(dp), allocatable :: ferr(:), berr(:), work(:)
      integer(ilp), allocatable :: iwork(:)

      ! Error handler.
      type(linalg_state_type) :: err

      ! Initialize data.
      n = A%n; nrhs = size(b, 2); x = b
      ! LAPACK rejects leading dimensions below max(1, n), even when n = 0.
      ld = max(1_ilp, n)

      !------------------------------------
      !-----     LU factorization     -----
      !------------------------------------

      ! ----- Allocations -----
      ! Work on copies: gttrf overwrites its inputs, and gtrfs still needs the
      ! original diagonals of A for the refinement step.
      allocate (du2(n - 2), ipiv(n))
      dl = A%ev; d = A%dv; du = A%ev

      ! ----- LU factorization -----
      call gttrf(n, dl, d, du, du2, ipiv, info)
      call handle_gttrf_info(n, info, err)
      call linalg_error_handling(err)

      !-------------------------------------
      !-----     Tridiagonal solve     -----
      !-------------------------------------

      ! ----- Solve the system (x holds b on entry, the solution on exit) -----
      call gttrs("N", n, nrhs, dl, d, du, du2, ipiv, x, ld, info)
      call handle_gttrs_info("N", n, nrhs, ld, info, err)
      call linalg_error_handling(err)

      !----------------------------------------
      !-----     Iterative refinement     -----
      !----------------------------------------

      if (refine) then
         ! ----- Allocate arrays -----
         allocate (ferr(nrhs), berr(nrhs), work(3*n), iwork(n))
         ! ----- Refinement step -----
         ! A is symmetric, so its sub- and super-diagonals are both A%ev.
         call gtrfs("N", n, nrhs, A%ev, A%dv, A%ev, dl, d, du, du2, ipiv, b, &
                    ld, x, ld, ferr, berr, work, iwork, info)
         call handle_gtrfs_info("N", n, nrhs, ld, ld, info, err)
         call linalg_error_handling(err)
      end if
   end function symtridiagonal_solver

   !-----------------------------------------------------------
   !-----     Positive-definite SymTridiagonal Solver     -----
   !-----------------------------------------------------------

   ! Process PTTRF
   elemental subroutine handle_pttrf_info(n, info, err)
      !! Translate the `info` flag returned by `pttrf` into a `linalg_state_type`.
      !! Zero is success, a positive value means the factorization failed, and
      !! negative values identify the invalid argument.
      integer(ilp), intent(in) :: n, info
      type(linalg_state_type), intent(inout) :: err

      ! Success only refreshes the state flag.
      if (info == 0) then
         err%state = LINALG_SUCCESS
         return
      end if

      ! Positive flag: the L D L^T factorization broke down.
      if (info > 0) then
         err = linalg_state_type(this, LINALG_ERROR, "Matrix could not be factorized.")
         return
      end if

      ! Negative flag: report the offending argument (last to first).
      select case (info)
      case (-3)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for E.")
      case (-2)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for D.")
      case (-1)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid matrix dimension n=", n)
      case default
         err = linalg_state_type(this, LINALG_INTERNAL_ERROR, "Unknown error returned by pttrf")
      end select
   end subroutine handle_pttrf_info

   ! Process PTTRS
   elemental subroutine handle_pttrs_info(n, nrhs, ldb, info, err)
      !! Translate the `info` flag returned by `pttrs` into a `linalg_state_type`.
      !! Zero is success; negative values identify the invalid argument; any
      !! other value is reported as an internal error.
      integer(ilp), intent(in) :: n, nrhs, ldb, info
      type(linalg_state_type), intent(inout) :: err

      ! Success only refreshes the state flag.
      if (info == 0) then
         err%state = LINALG_SUCCESS
         return
      end if

      ! Failure: report the offending argument (last to first).
      select case (info)
      case (-6)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for ldb=", ldb)
      case (-5)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for B.")
      case (-4)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for E.")
      case (-3)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for D.")
      case (-2)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for nrhs=", nrhs)
      case (-1)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for n=", n)
      case default
         err = linalg_state_type(this, LINALG_INTERNAL_ERROR, "Unknown error returned by pttrs")
      end select
   end subroutine handle_pttrs_info

   ! Process PTRFS
   elemental subroutine handle_ptrfs_info(n, nrhs, ldb, ldx, info, err)
      !! Translate the `info` flag returned by `ptrfs` into a `linalg_state_type`.
      !! Zero is success; negative values identify the invalid argument; any
      !! other value is reported as an internal error.
      integer(ilp), intent(in) :: n, nrhs, ldb, ldx, info
      type(linalg_state_type), intent(inout) :: err

      ! Success only refreshes the state flag.
      if (info == 0) then
         err%state = LINALG_SUCCESS
         return
      end if

      ! Failure: report the offending argument (last to first).
      select case (info)
      case (-13)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for work.")
      case (-12)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for berr.")
      case (-11)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for ferr.")
      case (-10)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for ldx=", ldx)
      case (-9)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for X.")
      case (-8)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for ldb=", ldb)
      case (-7)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for B.")
      case (-6)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for EF.")
      case (-5)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for DF.")
      case (-4)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for E.")
      case (-3)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for D.")
      case (-2)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for nrhs=", nrhs)
      case (-1)
         err = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid value for n=", n)
      case default
         err = linalg_state_type(this, LINALG_INTERNAL_ERROR, "Unknown error returned by ptrfs")
      end select
   end subroutine handle_ptrfs_info

   function posdef_symtridiagonal_solver(A, b, refine) result(x)
      !! Solve A x = b for a symmetric positive-definite tridiagonal matrix using
      !! the L D L^T factorization (`pttrf` + `pttrs`), optionally followed by
      !! one call to the iterative-refinement routine `ptrfs`.
      !!
      !! Every LAPACK status is forwarded to `linalg_error_handling`. Since this
      !! function returns no error argument, a failure (e.g. A flagged positive
      !! definite but not actually so) halts instead of silently returning
      !! garbage.
      type(SymTridiagonal), intent(in) :: A
      !! Coefficient matrix.
      real(dp), intent(in) :: b(:, :)
      !! Right-hand side vectors.
      logical(lk), intent(in) :: refine
      !! Iterative refinement of the solution?
      real(dp), allocatable :: x(:, :)
      !! Solution vectors.

      ! General LAPACK variables.
      integer(ilp) :: n, nrhs, ld, info
      ! LAPACK variables for LDL^T decomposition.
      real(dp), allocatable :: dv(:), ev(:)
      ! LAPACK variables for iterative refinement.
      real(dp), allocatable :: ferr(:), berr(:), work(:)

      ! Error handler.
      type(linalg_state_type) :: err

      ! Initialize data.
      n = A%n; nrhs = size(b, 2); x = b
      ! LAPACK rejects leading dimensions below max(1, n), even when n = 0.
      ld = max(1_ilp, n)

      !---------------------------------------
      !-----     LDL^T factorization     -----
      !---------------------------------------

      ! Work on copies: pttrf overwrites its inputs, and ptrfs still needs the
      ! original diagonals of A for the refinement step.
      ev = A%ev; dv = A%dv
      call pttrf(n, dv, ev, info)
      call handle_pttrf_info(n, info, err)
      call linalg_error_handling(err)

      !-------------------------------------
      !-----     Tridiagonal solve     -----
      !-------------------------------------

      ! ----- Solve the system (x holds b on entry, the solution on exit) -----
      call pttrs(n, nrhs, dv, ev, x, ld, info)
      call handle_pttrs_info(n, nrhs, ld, info, err)
      call linalg_error_handling(err)

      !----------------------------------------
      !-----     Iterative refinement     -----
      !----------------------------------------

      if (refine) then
         ! ----- Allocate arrays -----
         allocate (ferr(nrhs), berr(nrhs), work(2*n))
         ! ----- Refinement step -----
         call ptrfs(n, nrhs, A%dv, A%ev, dv, ev, b, ld, x, ld, ferr, berr, work, info)
         call handle_ptrfs_info(n, nrhs, ld, ld, info, err)
         call linalg_error_handling(err)
      end if
   end function posdef_symtridiagonal_solver

end submodule