
Matrix Exponentiation

Last Updated : 08 Apr, 2025

Matrix Exponentiation is a technique for efficiently calculating a matrix raised to a power: $M^N$ can be computed with O(log N) matrix multiplications instead of N − 1. It is mostly used for solving problems involving linear recurrences.

Idea behind Matrix Exponentiation:

Similar to Binary Exponentiation, which calculates a number raised to a power by repeated squaring, Matrix Exponentiation uses the same repeated squaring to calculate a matrix raised to a power efficiently.
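For comparison, here is a minimal Python sketch of scalar Binary Exponentiation (the function name `binary_pow` and the default modulus 10^9 + 7 are illustrative choices, not part of the original article). Matrix Exponentiation follows the same loop, with matrix multiplication in place of integer multiplication and the identity matrix in place of 1.

```python
def binary_pow(base: int, expo: int, mod: int = 10**9 + 7) -> int:
    """Compute (base ** expo) % mod using O(log expo) multiplications."""
    result = 1 % mod
    base %= mod
    while expo > 0:
        if expo & 1:
            # Lowest bit is set: fold the current power of base into result
            result = result * base % mod
        # Square base so it represents the power for the next bit
        base = base * base % mod
        expo >>= 1
    return result


print(binary_pow(3, 13))  # 1594323
```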

Let us understand how Matrix Exponentiation works.

Given a square matrix M, we can calculate $M^N$ using only O(log N) matrix multiplications. The idea is the same as in Binary Exponentiation:

When calculating $M^N$, there are three cases for a non-negative integer N:

  • Case 1: If N = 0, then whatever the value of M, the result is the Identity Matrix I.
  • Case 2: If N is even, then instead of calculating $M^N$, we can calculate $(M^2)^{N/2}$, and the result is the same.
  • Case 3: If N is odd, then instead of calculating $M^N$, we can calculate $M \cdot \left(M^{(N-1)/2}\right)^2$.
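The three cases translate almost line by line into a recursive function. The Python sketch below is illustrative only: the helper names `mat_mul` and `mat_pow` are hypothetical, it works for any K x K matrix stored as a list of lists, and it assumes results are wanted modulo 10^9 + 7.

```python
Matrix = list[list[int]]

MOD = 10**9 + 7


def mat_mul(A: Matrix, B: Matrix, mod: int = MOD) -> Matrix:
    """Multiply two K x K matrices, reducing every entry modulo `mod`."""
    k = len(A)
    return [[sum(A[i][t] * B[t][j] for t in range(k)) % mod for j in range(k)]
            for i in range(k)]


def mat_pow(M: Matrix, n: int, mod: int = MOD) -> Matrix:
    """Compute M^n recursively, following the three cases."""
    k = len(M)
    if n == 0:
        # Case 1: M^0 is the K x K identity matrix
        return [[1 % mod if i == j else 0 for j in range(k)] for i in range(k)]
    if n % 2 == 0:
        # Case 2: M^n = (M^2)^(n/2)
        return mat_pow(mat_mul(M, M, mod), n // 2, mod)
    # Case 3: M^n = M * (M^((n-1)/2))^2
    half = mat_pow(M, (n - 1) // 2, mod)
    return mat_mul(M, mat_mul(half, half, mod), mod)


print(mat_pow([[1, 1], [1, 0]], 10))  # [[89, 55], [55, 34]]
```

Each recursive call at least halves the exponent, so the recursion depth, and hence the number of matrix multiplications, is O(log N). The implementations later in this article use the equivalent iterative loop, which avoids recursion.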

Use Cases of Matrix Exponentiation:

Finding nth Fibonacci Number:

The recurrence relation for Fibonacci Sequence is F(n) = F(n - 1) + F(n - 2) starting with F(0) = 0 and F(1) = 1.
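In matrix form, one multiplication by $M = \begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}$ advances a pair of consecutive Fibonacci numbers by one step, so applying it repeatedly turns the recurrence into a matrix power:

$$
\begin{pmatrix} F(n) \\ F(n-1) \end{pmatrix}
= \begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}
\begin{pmatrix} F(n-1) \\ F(n-2) \end{pmatrix}
\quad\Longrightarrow\quad
\begin{pmatrix} F(n) \\ F(n-1) \end{pmatrix}
= \begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}^{\,n-1}
\begin{pmatrix} F(1) \\ F(0) \end{pmatrix},
\qquad n \ge 1.
$$

This is why the implementation raises M to the power n − 1 and multiplies the result by the column (F(1), F(0)) = (1, 0); F(n) is then the top-left entry.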


Below is the implementation of the above idea, which returns F(n) modulo 10^9 + 7:

C++
```cpp
// C++ Program to find the Nth fibonacci number using
// Matrix Exponentiation
#include <bits/stdc++.h>
using namespace std;

const int MOD = 1e9 + 7;

// Function to multiply two 2x2 matrices (result is stored in A)
void multiply(vector<vector<long long> >& A,
              vector<vector<long long> >& B)
{
    // Matrix to store the result
    vector<vector<long long> > C(2, vector<long long>(2));

    // Matrix Multiply
    C[0][0] = (A[0][0] * B[0][0] + A[0][1] * B[1][0]) % MOD;
    C[0][1] = (A[0][0] * B[0][1] + A[0][1] * B[1][1]) % MOD;
    C[1][0] = (A[1][0] * B[0][0] + A[1][1] * B[1][0]) % MOD;
    C[1][1] = (A[1][0] * B[0][1] + A[1][1] * B[1][1]) % MOD;

    // Copy the result back to the first matrix
    A[0][0] = C[0][0];
    A[0][1] = C[0][1];
    A[1][0] = C[1][0];
    A[1][1] = C[1][1];
}

// Function to find (Matrix M ^ expo)
vector<vector<long long> > power(vector<vector<long long> > M, int expo)
{
    // Initialize result with identity matrix
    vector<vector<long long> > ans = { { 1, 0 }, { 0, 1 } };

    // Fast Exponentiation
    while (expo) {
        if (expo & 1)
            multiply(ans, M);
        multiply(M, M);
        expo >>= 1;
    }
    return ans;
}

// Function to find the nth fibonacci number
int nthFibonacci(int n)
{
    // Base cases: F(0) = 0, F(1) = 1
    if (n == 0 || n == 1)
        return n;

    vector<vector<long long> > M = { { 1, 1 }, { 1, 0 } };

    // Matrix F = {{F(1), 0}, {F(0), 0}}, where F(1) and F(0)
    // are the first two terms of the fibonacci sequence
    vector<vector<long long> > F = { { 1, 0 }, { 0, 0 } };

    // Raise matrix M to the power (n - 1)
    vector<vector<long long> > res = power(M, n - 1);

    // Multiply Resultant with Matrix F
    multiply(res, F);

    return res[0][0] % MOD;
}

int main()
{
    // Sample Input
    int n = 3;

    // Print nth fibonacci number
    cout << nthFibonacci(n) << endl;
}
```
Java
```java
// Java Program to find the Nth fibonacci number using
// Matrix Exponentiation
public class GFG {

    static final int MOD = 1000000007;

    // Function to multiply two 2x2 matrices (result is stored in A)
    public static void multiply(long[][] A, long[][] B) {
        // Matrix to store the result
        long[][] C = new long[2][2];

        // Matrix multiplication
        C[0][0] = (A[0][0] * B[0][0] + A[0][1] * B[1][0]) % MOD;
        C[0][1] = (A[0][0] * B[0][1] + A[0][1] * B[1][1]) % MOD;
        C[1][0] = (A[1][0] * B[0][0] + A[1][1] * B[1][0]) % MOD;
        C[1][1] = (A[1][0] * B[0][1] + A[1][1] * B[1][1]) % MOD;

        // Copy the result back to the first matrix
        for (int i = 0; i < 2; i++) {
            for (int j = 0; j < 2; j++) {
                A[i][j] = C[i][j];
            }
        }
    }

    // Function to find (Matrix M ^ expo)
    public static long[][] power(long[][] M, int expo) {
        // Initialize result with identity matrix
        long[][] ans = { { 1, 0 }, { 0, 1 } };

        // Fast exponentiation
        while (expo > 0) {
            if ((expo & 1) != 0) {
                multiply(ans, M);
            }
            multiply(M, M);
            expo >>= 1;
        }
        return ans;
    }

    // Function to find the nth Fibonacci number
    public static int nthFibonacci(int n) {
        // Base cases: F(0) = 0, F(1) = 1
        if (n == 0 || n == 1) {
            return n;
        }

        long[][] M = { { 1, 1 }, { 1, 0 } };

        // F = {{F(1), 0}, {F(0), 0}}
        long[][] F = { { 1, 0 }, { 0, 0 } };

        // Raise matrix M to the power (n - 1)
        long[][] res = power(M, n - 1);

        // Multiply resultant with matrix F
        multiply(res, F);

        return (int) (res[0][0] % MOD);
    }

    public static void main(String[] args) {
        // Sample input
        int n = 3;

        // Print nth Fibonacci number
        System.out.println(nthFibonacci(n));
    }
}
```
Python
```python
# Python Program to find the Nth fibonacci number using
# Matrix Exponentiation

MOD = 10**9 + 7

# Function to multiply two 2x2 matrices (result is stored in A)
def multiply(A, B):
    # Matrix to store the result
    C = [[0, 0], [0, 0]]

    # Matrix Multiply
    C[0][0] = (A[0][0] * B[0][0] + A[0][1] * B[1][0]) % MOD
    C[0][1] = (A[0][0] * B[0][1] + A[0][1] * B[1][1]) % MOD
    C[1][0] = (A[1][0] * B[0][0] + A[1][1] * B[1][0]) % MOD
    C[1][1] = (A[1][0] * B[0][1] + A[1][1] * B[1][1]) % MOD

    # Copy the result back to the first matrix
    A[0][0] = C[0][0]
    A[0][1] = C[0][1]
    A[1][0] = C[1][0]
    A[1][1] = C[1][1]

# Function to find (Matrix M ^ expo)
def power(M, expo):
    # Initialize result with identity matrix
    ans = [[1, 0], [0, 1]]

    # Fast Exponentiation
    while expo:
        if expo & 1:
            multiply(ans, M)
        multiply(M, M)
        expo >>= 1

    return ans

# Function to find the nth fibonacci number
def nthFibonacci(n):
    # Base cases: F(0) = 0, F(1) = 1
    if n == 0 or n == 1:
        return n

    M = [[1, 1], [1, 0]]

    # F = [[F(1), 0], [F(0), 0]]
    F = [[1, 0], [0, 0]]

    # Raise matrix M to the power (n - 1)
    res = power(M, n - 1)

    # Multiply Resultant with Matrix F
    multiply(res, F)

    return res[0][0] % MOD

# Sample Input
n = 3

# Print the nth fibonacci number
print(nthFibonacci(n))
```
C#
```csharp
// C# Program to find the Nth fibonacci number using
// Matrix Exponentiation
using System;

public class GFG {
    static int MOD = 1000000007;

    // Function to multiply two 2x2 matrices (result is stored in A)
    public static void Multiply(long[][] A, long[][] B)
    {
        // Matrix to store the result
        long[][] C = new long[2][] { new long[2], new long[2] };

        // Matrix Multiply
        C[0][0] = (A[0][0] * B[0][0] + A[0][1] * B[1][0]) % MOD;
        C[0][1] = (A[0][0] * B[0][1] + A[0][1] * B[1][1]) % MOD;
        C[1][0] = (A[1][0] * B[0][0] + A[1][1] * B[1][0]) % MOD;
        C[1][1] = (A[1][0] * B[0][1] + A[1][1] * B[1][1]) % MOD;

        // Copy the result back to the first matrix
        A[0][0] = C[0][0];
        A[0][1] = C[0][1];
        A[1][0] = C[1][0];
        A[1][1] = C[1][1];
    }

    // Function to find (Matrix M ^ expo)
    public static long[][] Power(long[][] M, int expo)
    {
        // Initialize result with identity matrix
        long[][] ans = new long[2][] { new long[] { 1, 0 },
                                       new long[] { 0, 1 } };

        // Fast Exponentiation
        while (expo > 0) {
            if ((expo & 1) > 0)
                Multiply(ans, M);
            Multiply(M, M);
            expo >>= 1;
        }
        return ans;
    }

    // Function to find the nth fibonacci number
    public static int NthFibonacci(int n)
    {
        // Base cases: F(0) = 0, F(1) = 1
        if (n == 0 || n == 1)
            return n;

        long[][] M = new long[2][] { new long[] { 1, 1 },
                                     new long[] { 1, 0 } };

        // F = {{F(1), 0}, {F(0), 0}}
        long[][] F = new long[2][] { new long[] { 1, 0 },
                                     new long[] { 0, 0 } };

        // Raise matrix M to the power (n - 1)
        long[][] res = Power(M, n - 1);

        // Multiply Resultant with Matrix F
        Multiply(res, F);

        return (int)(res[0][0] % MOD);
    }

    public static void Main(string[] args)
    {
        // Sample Input
        int n = 3;

        // Print nth fibonacci number
        Console.WriteLine(NthFibonacci(n));
    }
}
```
JavaScript
```javascript
// JavaScript Program to find the Nth fibonacci number using
// Matrix Exponentiation.
// BigInt is used because the product of two values below MOD
// can exceed Number.MAX_SAFE_INTEGER (2^53 - 1).
const MOD = 1000000007n;

// Function to multiply two 2x2 matrices (result is stored in A)
function multiply(A, B) {
    // Matrix to store the result
    const C = [
        [0n, 0n],
        [0n, 0n]
    ];

    // Matrix Multiply
    C[0][0] = (A[0][0] * B[0][0] + A[0][1] * B[1][0]) % MOD;
    C[0][1] = (A[0][0] * B[0][1] + A[0][1] * B[1][1]) % MOD;
    C[1][0] = (A[1][0] * B[0][0] + A[1][1] * B[1][0]) % MOD;
    C[1][1] = (A[1][0] * B[0][1] + A[1][1] * B[1][1]) % MOD;

    // Copy the result back to the first matrix
    A[0][0] = C[0][0];
    A[0][1] = C[0][1];
    A[1][0] = C[1][0];
    A[1][1] = C[1][1];
}

// Function to find (Matrix M ^ expo)
function power(M, expo) {
    // Initialize result with identity matrix
    const ans = [
        [1n, 0n],
        [0n, 1n]
    ];

    // Fast Exponentiation
    while (expo) {
        if (expo & 1) multiply(ans, M);
        multiply(M, M);
        expo >>= 1;
    }
    return ans;
}

// Function to find the nth fibonacci number
function nthFibonacci(n) {
    // Base cases: F(0) = 0, F(1) = 1
    if (n === 0 || n === 1) return n;

    const M = [
        [1n, 1n],
        [1n, 0n]
    ];

    // F = [[F(1), 0], [F(0), 0]]
    const F = [
        [1n, 0n],
        [0n, 0n]
    ];

    // Raise matrix M to the power (n - 1)
    const res = power(M, n - 1);

    // Multiply Resultant with Matrix F
    multiply(res, F);

    return Number(res[0][0] % MOD);
}

// Sample Input
const n = 3;

// Print nth fibonacci number
console.log(nthFibonacci(n));
```

Output
2 

Time Complexity: O(log N), because fast exponentiation performs O(log N) multiplications of 2x2 matrices, and each multiplication takes constant time.
Auxiliary Space: O(1)
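As a worked example of the O(log N) bound, take n = 10^9 in nthFibonacci(), so power() receives the exponent n − 1 = 999,999,999. The loop runs once per binary digit of the exponent, and each iteration performs at most two 2x2 matrix multiplications of 8 scalar products each:

$$
\lfloor \log_2 (999{,}999{,}999) \rfloor + 1 = 29 + 1 = 30 \ \text{iterations},
\qquad
2 \times 30 = 60 \ \text{matrix multiplications},
\qquad
60 \times 8 = 480 \ \text{scalar products}.
$$

Together with the final multiplication by F, this is a few hundred arithmetic operations, compared with roughly 10^9 additions for a simple loop.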

Finding nth Tribonacci Number:

The recurrence relation for Tribonacci Sequence is T(n) = T(n - 1) + T(n - 2) + T(n - 3) starting with T(0) = 0, T(1) = 1 and T(2) = 1.
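The same trick works with a 3x3 matrix: one multiplication shifts a window of three consecutive Tribonacci numbers forward by one step, so

$$
\begin{pmatrix} T(n) \\ T(n-1) \\ T(n-2) \end{pmatrix}
= \begin{pmatrix} 1 & 1 & 1 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{pmatrix}^{\,n-2}
\begin{pmatrix} T(2) \\ T(1) \\ T(0) \end{pmatrix},
\qquad n \ge 2.
$$

This is why the implementation raises M to the power n − 2 and multiplies the result by the column (T(2), T(1), T(0)) = (1, 1, 0).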


Below is the implementation of the above idea:

C++
```cpp
// C++ Program to find the nth tribonacci number
#include <bits/stdc++.h>
using namespace std;

// Function to multiply two 3x3 matrices (result is stored in A).
// Note: no modulus is applied, so entries can overflow
// 64-bit integers for large n.
void multiply(vector<vector<long long> >& A,
              vector<vector<long long> >& B)
{
    // Matrix to store the result
    vector<vector<long long> > C(3, vector<long long>(3));

    for (int i = 0; i < 3; i++) {
        for (int j = 0; j < 3; j++) {
            for (int k = 0; k < 3; k++) {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }

    // Copy the result back to the first matrix
    for (int i = 0; i < 3; i++) {
        for (int j = 0; j < 3; j++) {
            A[i][j] = C[i][j];
        }
    }
}

// Function to calculate (Matrix M) ^ expo
vector<vector<long long> > power(vector<vector<long long> > M, int expo)
{
    // Initialize result with identity matrix
    vector<vector<long long> > ans
        = { { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, 1 } };

    // Fast Exponentiation
    while (expo) {
        if (expo & 1)
            multiply(ans, M);
        multiply(M, M);
        expo >>= 1;
    }
    return ans;
}

// Function to return the Nth tribonacci number
long long tribonacci(int n)
{
    // Base cases: T(0) = 0, T(1) = 1
    if (n == 0 || n == 1)
        return n;

    // Matrix M to generate the next tribonacci number
    vector<vector<long long> > M
        = { { 1, 1, 1 }, { 1, 0, 0 }, { 0, 1, 0 } };

    // Since the first 3 numbers of the tribonacci series are:
    // trib(0) = 0
    // trib(1) = 1
    // trib(2) = 1
    // F = {{trib(2), 0, 0}, {trib(1), 0, 0}, {trib(0), 0, 0}}
    vector<vector<long long> > F
        = { { 1, 0, 0 }, { 1, 0, 0 }, { 0, 0, 0 } };

    vector<vector<long long> > res = power(M, n - 2);
    multiply(res, F);
    return res[0][0];
}

int main()
{
    // Sample Input
    int n = 4;

    // Function call
    cout << tribonacci(n);
    return 0;
}
```
Java
```java
// Java program to find nth tribonacci number
public class Main {

    // Function to multiply two 3x3 matrices (result is stored in A).
    // Note: no modulus is applied, so entries can overflow
    // 64-bit integers for large n.
    static void multiply(long[][] A, long[][] B)
    {
        // Matrix to store the result
        long[][] C = new long[3][3];

        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < 3; j++) {
                for (int k = 0; k < 3; k++) {
                    C[i][j] += A[i][k] * B[k][j];
                }
            }
        }

        // Copy the result back to the first matrix
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < 3; j++) {
                A[i][j] = C[i][j];
            }
        }
    }

    // Function to calculate (Matrix M) ^ expo
    static long[][] power(long[][] M, int expo)
    {
        // Initialize result with identity matrix
        long[][] ans
            = { { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, 1 } };

        // Fast Exponentiation
        while (expo > 0) {
            if ((expo & 1) == 1)
                multiply(ans, M);
            multiply(M, M);
            expo >>= 1;
        }
        return ans;
    }

    // Function to return the Nth tribonacci number
    static long tribonacci(int n)
    {
        // Base cases: T(0) = 0, T(1) = 1
        if (n == 0 || n == 1)
            return n;

        // Matrix M to generate the next tribonacci number
        long[][] M
            = { { 1, 1, 1 }, { 1, 0, 0 }, { 0, 1, 0 } };

        // Since the first 3 numbers of the tribonacci series are:
        // trib(0) = 0
        // trib(1) = 1
        // trib(2) = 1
        // F = {{trib(2), 0, 0}, {trib(1), 0, 0}, {trib(0), 0, 0}}
        long[][] F
            = { { 1, 0, 0 }, { 1, 0, 0 }, { 0, 0, 0 } };

        long[][] res = power(M, n - 2);
        multiply(res, F);
        return res[0][0];
    }

    public static void main(String[] args)
    {
        // Sample Input
        int n = 4;

        // Function call
        System.out.println(tribonacci(n));
    }
}
```
Python
```python
# Python3 program to find nth tribonacci number

# Function to multiply two 3x3 matrices (result is stored in A)
def multiply(A, B):
    # Matrix to store the result
    C = [[0] * 3 for _ in range(3)]

    for i in range(3):
        for j in range(3):
            for k in range(3):
                C[i][j] += A[i][k] * B[k][j]

    # Copy the result back to the first matrix A
    for i in range(3):
        for j in range(3):
            A[i][j] = C[i][j]

# Function to calculate (Matrix M) ^ expo
def power(M, expo):
    # Initialize result with identity matrix
    ans = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]

    # Fast Exponentiation
    while expo > 0:
        if expo & 1:
            multiply(ans, M)
        multiply(M, M)
        expo >>= 1

    return ans

# Function to return the Nth tribonacci number
def tribonacci(n):
    # Base cases: T(0) = 0, T(1) = 1
    if n == 0 or n == 1:
        return n

    # Matrix M to generate the next tribonacci number
    M = [[1, 1, 1], [1, 0, 0], [0, 1, 0]]

    # Since the first 3 numbers of the tribonacci series are:
    # trib(0) = 0
    # trib(1) = 1
    # trib(2) = 1
    # F = [[trib(2), 0, 0], [trib(1), 0, 0], [trib(0), 0, 0]]
    F = [[1, 0, 0], [1, 0, 0], [0, 0, 0]]

    res = power(M, n - 2)
    multiply(res, F)
    return res[0][0]

# Sample Input
n = 4

# Function call
print(tribonacci(n))
```
C#
```csharp
// C# program to find nth tribonacci number
using System;

public class MainClass {

    // Function to multiply two 3x3 matrices (result is stored in A).
    // Note: no modulus is applied, so entries can overflow
    // 64-bit integers for large n.
    static void Multiply(long[][] A, long[][] B)
    {
        // Matrix to store the result
        long[][] C = new long[3][];
        for (int i = 0; i < 3; i++) {
            C[i] = new long[3];
            for (int j = 0; j < 3; j++) {
                for (int k = 0; k < 3; k++) {
                    C[i][j] += A[i][k] * B[k][j];
                }
            }
        }

        // Copy the result back to the first matrix A
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < 3; j++) {
                A[i][j] = C[i][j];
            }
        }
    }

    // Function to calculate (Matrix M) ^ expo
    static long[][] Power(long[][] M, int expo)
    {
        // Initialize result with identity matrix
        long[][] ans = new long[3][];
        for (int i = 0; i < 3; i++) {
            ans[i] = new long[3];
            ans[i][i] = 1; // Diagonal elements are 1
        }

        // Fast Exponentiation
        while (expo > 0) {
            if ((expo & 1) == 1)
                Multiply(ans, M);
            Multiply(M, M);
            expo >>= 1;
        }
        return ans;
    }

    // Function to return the Nth tribonacci number
    static long Tribonacci(int n)
    {
        // Base cases: T(0) = 0, T(1) = 1
        if (n == 0 || n == 1)
            return n;

        // Matrix M to generate the next tribonacci number
        long[][] M = new long[][] { new long[] { 1, 1, 1 },
                                    new long[] { 1, 0, 0 },
                                    new long[] { 0, 1, 0 } };

        // Since the first 3 numbers of the tribonacci series are:
        // trib(0) = 0
        // trib(1) = 1
        // trib(2) = 1
        // F = [[trib(2), 0, 0], [trib(1), 0, 0], [trib(0), 0, 0]]
        long[][] F = new long[][] { new long[] { 1, 0, 0 },
                                    new long[] { 1, 0, 0 },
                                    new long[] { 0, 0, 0 } };

        long[][] res = Power(M, n - 2);
        Multiply(res, F);
        return res[0][0];
    }

    public static void Main(string[] args)
    {
        // Sample Input
        int n = 4;

        // Function call
        Console.WriteLine(Tribonacci(n));
    }
}
```
JavaScript
```javascript
// JavaScript Program to find nth tribonacci number

// Function to multiply two 3x3 matrices (result is stored in A).
// Note: plain Numbers are exact only up to 2^53 - 1,
// so results lose precision for large n.
function multiply(A, B) {
    // Matrix to store the result
    let C = [
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]
    ];

    for (let i = 0; i < 3; i++) {
        for (let j = 0; j < 3; j++) {
            for (let k = 0; k < 3; k++) {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }

    // Copy the result back to the first matrix A
    for (let i = 0; i < 3; i++) {
        for (let j = 0; j < 3; j++) {
            A[i][j] = C[i][j];
        }
    }
}

// Function to calculate (Matrix M) ^ expo
function power(M, expo) {
    // Initialize result with identity matrix
    let ans = [
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]
    ];

    // Fast Exponentiation
    while (expo > 0) {
        if (expo & 1)
            multiply(ans, M);
        multiply(M, M);
        expo >>= 1;
    }
    return ans;
}

// Function to return the Nth tribonacci number
function tribonacci(n) {
    // Base cases: T(0) = 0, T(1) = 1
    if (n === 0 || n === 1)
        return n;

    // Matrix M to generate the next tribonacci number
    let M = [
        [1, 1, 1],
        [1, 0, 0],
        [0, 1, 0]
    ];

    // Since the first 3 numbers of the tribonacci series are:
    // trib(0) = 0
    // trib(1) = 1
    // trib(2) = 1
    // F = [[trib(2), 0, 0], [trib(1), 0, 0], [trib(0), 0, 0]]
    let F = [
        [1, 0, 0],
        [1, 0, 0],
        [0, 0, 0]
    ];

    let res = power(M, n - 2);
    multiply(res, F);
    return res[0][0];
}

// Main function
function main() {
    // Sample Input
    let n = 4;

    // Function call
    console.log(tribonacci(n));
}

// Call the main function
main();
```

Output
4

Time Complexity: O(log N), because fast exponentiation performs O(log N) multiplications of 3x3 matrices, and each multiplication takes constant time.
Auxiliary Space: O(1)

Applications of Matrix Exponentiation:

Matrix Exponentiation has a variety of applications. Some of them are:

  • Linear recurrence relations with constant coefficients, such as the Fibonacci and Tribonacci sequences, can be solved using matrix exponentiation: a recurrence of order K becomes a K x K matrix, and its N-th term can be found in O(K^3 log N) time.
  • The RSA encryption algorithm relies on modular exponentiation of large integers, which uses the same square-and-multiply idea (Binary Exponentiation) that matrix exponentiation generalizes from numbers to matrices.
  • Dynamic programming problems, especially those involving linear recurrence relations, can be optimized using matrix exponentiation to reduce time complexity.
  • Matrix exponentiation is used in number theory problems involving modular arithmetic, such as finding the N-th term of a recurrence modulo some value when N is very large.
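To illustrate the first point, the Python sketch below builds the K x K companion matrix for a(n) = c1·a(n−1) + c2·a(n−2) + … + cK·a(n−K) and raises it to a power; the Fibonacci and Tribonacci matrices used earlier are the K = 2 and K = 3 cases. The helper `linear_recurrence` and its argument layout are hypothetical, and the sketch assumes integer coefficients and answers modulo 10^9 + 7.

```python
MOD = 10**9 + 7


def mat_mul(A, B, mod=MOD):
    """Multiply two K x K matrices modulo `mod`."""
    k = len(A)
    return [[sum(A[i][t] * B[t][j] for t in range(k)) % mod for j in range(k)]
            for i in range(k)]


def mat_pow(M, expo, mod=MOD):
    """Iterative square-and-multiply: returns M^expo modulo `mod`."""
    k = len(M)
    result = [[int(i == j) for j in range(k)] for i in range(k)]
    while expo > 0:
        if expo & 1:
            result = mat_mul(result, M, mod)
        M = mat_mul(M, M, mod)
        expo >>= 1
    return result


def linear_recurrence(coeffs, initial, n, mod=MOD):
    """
    Return a(n) mod `mod`, where
    a(n) = coeffs[0]*a(n-1) + coeffs[1]*a(n-2) + ... + coeffs[K-1]*a(n-K)
    and initial = [a(0), a(1), ..., a(K-1)].
    """
    k = len(coeffs)
    if n < k:
        return initial[n] % mod
    # Companion matrix: first row holds the coefficients,
    # the 1s below the diagonal shift every term down by one position.
    M = [[0] * k for _ in range(k)]
    M[0] = [c % mod for c in coeffs]
    for i in range(1, k):
        M[i][i - 1] = 1
    # M^(n-K+1) maps [a(K-1), ..., a(0)] to [a(n), ..., a(n-K+1)]
    P = mat_pow(M, n - k + 1, mod)
    state = initial[::-1]
    return sum(P[0][j] * state[j] for j in range(k)) % mod


print(linear_recurrence([1, 1, 1], [0, 1, 1], 4))  # 4, the Tribonacci example
```

Each multiplication of K x K matrices costs O(K^3), which gives the O(K^3 log N) bound; for recurrences of small order this is a clear win over O(K·N) iteration when N is large.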

Advantages of Matrix Exponentiation:

Advantages of using Matrix Exponentiation are:

  • Matrix Exponentiation finds the Nth term of linear recurrence relations like the Fibonacci or Tribonacci series in O(log N) time, which is much faster than the O(N) iterative approach for large values of N.
  • The iterative approach needs only O(1) extra space for a fixed matrix size (more generally, O(K^2) for a K x K matrix).
  • Very large terms can be handled without integer overflow by reducing every entry modulo a number such as 10^9 + 7 after each multiplication.

Disadvantages of Matrix Exponentiation:

Disadvantages of using Matrix Exponentiation are:

  • Matrix Exponentiation is more complex than a simple iterative or recursive method, which makes it harder to debug.
  • Initial conditions and the exponent (for example, n − 1 or n − 2) must be handled carefully to avoid off-by-one results.
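One practical safeguard against such mistakes is to compare the matrix version with a plain loop for the first few terms, where a wrong base case or exponent shows up immediately. The Python sketch below does this for Fibonacci; the names `fib_matrix` and `fib_naive` are illustrative and not part of the original article.

```python
MOD = 10**9 + 7


def mat_mul(A, B):
    """Multiply two K x K matrices modulo MOD."""
    k = len(A)
    return [[sum(A[i][t] * B[t][j] for t in range(k)) % MOD for j in range(k)]
            for i in range(k)]


def mat_pow(M, expo):
    """Iterative square-and-multiply: returns M^expo modulo MOD."""
    k = len(M)
    result = [[int(i == j) for j in range(k)] for i in range(k)]
    while expo > 0:
        if expo & 1:
            result = mat_mul(result, M)
        M = mat_mul(M, M)
        expo >>= 1
    return result


def fib_matrix(n):
    """F(n) mod MOD via (F(n), F(n-1)) = M^(n-1) (F(1), F(0))."""
    if n == 0:
        return 0  # F(0) must be handled before computing M^(n-1)
    return mat_pow([[1, 1], [1, 0]], n - 1)[0][0]


def fib_naive(n):
    """Reference O(n) loop starting from F(0) = 0, F(1) = 1."""
    a, b = 0, 1
    for _ in range(n):
        a, b = b, (a + b) % MOD
    return a


# Compare both versions on the first terms, where off-by-one mistakes show up.
mismatches = [n for n in range(50) if fib_matrix(n) != fib_naive(n)]
print(mismatches)  # []
```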

