Skip to content
107 changes: 107 additions & 0 deletions blockchain/merkle_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
Merkle Tree Construction and Verification

This module implements the construction of a Merkle Tree and
verification of inclusion proofs for blockchain data integrity.

Each leaf is a SHA-256 hash of a transaction, and internal nodes are
computed by hashing the concatenation of their child nodes.

References:
https://en.wikipedia.org/wiki/Merkle_tree
"""

import hashlib


def sha256(data: str) -> str:
"""
Compute the SHA-256 hash of the given string.

Args:
data (str): Input string.

Returns:
str: Hexadecimal SHA-256 hash of the input.

Example:
>>> sha256("abc")
'ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad'
"""
return hashlib.sha256(data.encode()).hexdigest()


def build_merkle_tree(leaves: list[str]) -> list[list[str]]:
"""
Build a Merkle Tree from the given leaf nodes.

Args:
leaves: List of data strings (transactions).

Returns:
A list of lists representing tree levels,
with the last level containing the Merkle root.

>>> len(build_merkle_tree(["a", "b", "c", "d"])[-1][0])
64
"""
if not leaves:
raise ValueError("Leaf list cannot be empty.")

current_level = [sha256(x) for x in leaves]
tree = [current_level]

while len(current_level) > 1:
next_level = []
for i in range(0, len(current_level), 2):
left = current_level[i]
right = current_level[i + 1] if i + 1 < len(current_level) else left
next_level.append(sha256(left + right))
current_level = next_level
tree.append(current_level)

return tree


def merkle_root(leaves: list[str]) -> str:
"""
Return the Merkle root hash for a given list of data.

>>> r = merkle_root(["tx1", "tx2", "tx3"])
>>> isinstance(r, str)
True
"""
return build_merkle_tree(leaves)[-1][0]


def verify_proof(leaf: str, proof: list[str], root: str) -> bool:
"""
Verify inclusion of a leaf using a Merkle proof.

Args:
leaf: Original data string.
proof: List of sibling hashes up the path.
root: Expected Merkle root hash.

Returns:
True if proof is valid, else False.

>>> data = ["a", "b", "c", "d"]
>>> tree = build_merkle_tree(data)
>>> root = tree[-1][0]
>>> leaf = "a"
>>> proof = [sha256("b"), sha256(sha256("c") + sha256("d"))]
>>> verify_proof(leaf, proof, root)
True
"""
computed_hash = sha256(leaf)
for sibling in proof:
combined = sha256(computed_hash + sibling)
computed_hash = combined
return computed_hash == root


if __name__ == "__main__":
import doctest

doctest.testmod()
76 changes: 76 additions & 0 deletions quantum/simons_algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
Simon's Algorithm (Classical Simulation)

Simon's algorithm finds a hidden bitstring s such that
f(input_bits) = f(other_bits) if and only if input_bits XOR other_bits = s.

Here we simulate the mapping behavior classically to
illustrate how the hidden period can be discovered by
analyzing collisions in f(input_bits).

References:
https://en.wikipedia.org/wiki/Simon's_problem
"""

from collections.abc import Callable
from itertools import product


def xor_bits(bits1: list[int], bits2: list[int]) -> list[int]:
"""
Return the bitwise XOR of two equal-length bit lists.

>>> xor_bits([1, 0, 1], [1, 1, 0])
[0, 1, 1]
"""
if len(bits1) != len(bits2):
raise ValueError("Bit lists must be of equal length.")
return [x ^ y for x, y in zip(bits1, bits2)]


def simons_algorithm(f: Callable[[list[int]], list[int]], num_bits: int) -> list[int]:
"""
Simulate Simon's algorithm classically to find the hidden bitstring s.

Args:
f: A function mapping n-bit input to n-bit output.
num_bits: Number of bits in the input.

Returns:
The hidden bitstring s as a list of bits.

>>> # Example with hidden bitstring s = [1, 0, 1]
>>> s = [1, 0, 1]
>>> def f(input_bits):
... mapping = {
... (0,0,0): (1,1,0),
... (1,0,1): (1,1,0),
... (0,0,1): (0,1,1),
... (1,0,0): (0,1,1),
... (0,1,0): (1,0,1),
... (1,1,1): (1,0,1),
... (0,1,1): (0,0,0),
... (1,1,0): (0,0,0),
... }
... return mapping[tuple(input_bits)]
>>> simons_algorithm(f, 3)
[1, 0, 1]
"""
mapping: dict[tuple[int, ...], tuple[int, ...]] = {}
inputs = list(product([0, 1], repeat=num_bits))

for bits in inputs:
fx = tuple(f(list(bits)))
if fx in mapping:
prev_bits = mapping[fx]
return xor_bits(list(bits), list(prev_bits))
mapping[fx] = bits

# If no collision found, function might be constant
return [0] * num_bits


if __name__ == "__main__":
import doctest

doctest.testmod()