Markov Chains are fun

Mon 25 February 2019

I am part of the course team for the Casimir programming course. Each year we take 50 students through a Software Carpentry-style intensive course in Python and scientific programming over the course of a week. The capstone is a project lasting a couple of days, where the students put into practice everything they've learned in the course.

Coming up with cool projects is a chore; however, I recently read a blog post about using Markov Chain Monte Carlo to decrypt substitution ciphers. This meshes well with the other themes in the course: on the first day there is a small exercise that uses some statistical analysis to decrypt substitution ciphers, but it is not very automatic. The blog post references a 2010 paper by some Master's students at the University of Toronto, which I used as inspiration.

The General Idea

We have some text that we know has been encrypted using a substitution cipher, but we do not know which encryption key was used.

The space that we are searching is the space of encryption keys. You can think of a key as a bijective map from the alphabet to itself, e.g. A → D, B → R, and so on; the associated decryption key is just the inverse of this map. For a given decryption key we can attempt to decrypt the ciphertext, which gives some cleartext that may or may not be correct. What is clear is that the more entries of the decryption key are correct, the closer the cleartext will be to the right answer. To score a candidate key, we compare the frequencies of pairs of adjacent letters (bigrams) in its cleartext with their frequencies in some reference text: a cleartext made of bigrams that are common in the reference text scores higher. If we use the ratio of the scores of two candidate keys (capped at 1) as our transition probability, then we can use a Markov chain to sample the space of keys and (if implemented well!) converge to the true key.
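Written out, in notation chosen here for illustration rather than copied from the paper: let $c_1 \ldots c_n$ be the characters of the normalized ciphertext, $f$ a candidate decryption key (which leaves the word marker unchanged), and $M(x, y)$ the reference-text probability that character $y$ follows character $x$. The score (plausibility) of $f$, and the probability of moving from $f$ to a proposed key $f'$ when the proposal is symmetric (as swapping two entries of a key is), are then

$$
\mathrm{Pl}(f) = \prod_{i=1}^{n-1} M\big(f(c_i), f(c_{i+1})\big),
\qquad
\alpha(f \to f') = \min\left(1, \frac{\mathrm{Pl}(f')}{\mathrm{Pl}(f)}\right).
$$

Taking logarithms turns the product into a sum, $\log \mathrm{Pl}(f) = \sum_{i=1}^{n-1} \log M\big(f(c_i), f(c_{i+1})\big)$, which is easier to handle numerically for long texts.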

Step 1: Get a reference text

We'll use a large corpus of English text as our reference. Luckily, Project Gutenberg has a large number of English texts. For this example we choose War and Peace.

In [1]:
from urllib.parse import urlparse
from string import ascii_lowercase, printable
from itertools import groupby, chain

import requests


def is_url(maybe_url):
    parsed_url = urlparse(maybe_url)
    return parsed_url.scheme and parsed_url.netloc


WORD_MARKER = ' '
ALPHABET = ascii_lowercase
ALLOWED_CHARS = frozenset(ALPHABET + WORD_MARKER)
EXCLUDED_CHARS = frozenset(printable) - ALLOWED_CHARS
ALPHA_TO_INDEX = {a: i for i, a in enumerate(ALPHABET)}


def normalize_text(text):
    """Normalize a text using certain rules
        
    The normalization rules are the following:
        + all alphabetic characters are converted to lowercase
        + all characters other than lowercase ASCII letters (punctuation,
          whitespace, digits, non-ASCII characters) are converted to an
          end-of-word marker character. We will only be analyzing the text on
          the level of the constituent words, not the grammar, so we only care
          about punctuation and whitespace because they indicate the start/end
          of a word.
        + runs of consecutive end-of-word markers are collapsed into one marker.
    """
    text = text.lower()
    # normalize punctuation to whitespace. Probably incorrect for hyphenation,
    # but we hope that hyphenated words are rare. This also catches
    # (and ignores) non-ascii characters
    text = ((c if c in ALLOWED_CHARS else WORD_MARKER) for c in text)
    # remove duplicates of WORD_MARKER
    text = chain.from_iterable(c if c == WORD_MARKER else g for c, g in groupby(text))
    return ''.join(text)
    

# TODO: convert this to work on streams, for truly huge reference texts,
#       to avoid reading the whole reference text into memory at once
def get_reference_text(name):
    """Returns a normalized reference text as a string.
    
    See the documentation for 'normalize_text' for details of the normalization.
    
    Parameters
    ----------
    name : str
        The name of the text to fetch; either a path to a file or a URL.
        If a URL is provided, GETting the URL must return the text.
    """
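    try:
        if is_url(name):
            response = requests.get(name)
            # fail loudly on HTTP errors instead of normalizing an error page
            response.raise_for_status()
            text = response.text
        else:
            with open(name) as file:
                text = file.read()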
    except Exception as error:
        msg = f'There was a problem fetching the text from "{name}"'
        raise ValueError(msg) from error
    
    return normalize_text(text)
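For comparison, the same normalization rules can be expressed with a single regular expression. This is a sketch using Python's 're' module instead of the 'groupby'/'chain' approach above; 'normalize_text_regex' is a hypothetical name, and the claim that it matches 'normalize_text' rests on both of them replacing every run of characters outside a-z with one space after lowercasing.

```python
import re


def normalize_text_regex(text):
    """Lowercase 'text' and turn every run of characters outside a-z into one space.

    Hypothetical regex-based sketch of the rules documented in 'normalize_text'.
    """
    # After lowercasing, anything that is not an ASCII lowercase letter
    # (whitespace, punctuation, digits, non-ASCII letters) acts as a word
    # boundary, and a whole run of such characters collapses to one marker.
    return re.sub(r'[^a-z]+', ' ', text.lower())


print(repr(normalize_text_regex("Hello, World!")))  # 'hello world '
```

Because a run of separators collapses to a single space wherever it occurs, the result can start or end with a space, just like the normalized example sentence printed further down.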
In [2]:
war_and_peace = get_reference_text('http://www.gutenberg.org/files/2600/2600-0.txt')

Next we need a few utilities for counting bigrams in a text and for constructing the matrix of probabilities of finding a given letter at position $X+1$, given the letter at position $X$. This is exactly the matrix of bigram frequencies, normalized so that each row sums to 1.

In [3]:
from collections import Counter
from operator import mul
from functools import reduce
from itertools import islice


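def pairs(sequence):
    "Return the pairs of adjacent items in 'sequence' (a sequence, not an iterator)."
    # islice(sequence, 1, None) skips the first item; islice(sequence, 1) would keep *only* it
    return zip(sequence, islice(sequence, 1, None))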


def prod(iterable):
    return reduce(mul, iterable, 1)


def take(n, it):
    return islice(it, n)


def count_bigrams(text):
    "Return the bigrams in a text as a dict (char1, char2) → count."
    return Counter(pairs(text))


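def construct_transitions(text):
    """Return P(next char | current char) as a Counter keyed by (current, next).

    Each row (fixed current char) of the bigram counts is divided by its total,
    so every non-empty row sums to 1.
    """
    transitions = count_bigrams(text)
    for c in ALLOWED_CHARS:
        total = sum(transitions[c, p] for p in ALLOWED_CHARS)
        if total == 0:
            continue
        for p in ALLOWED_CHARS:
            transitions[c, p] /= total
    return transitions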
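The TODO in 'get_reference_text' mentions working on streams for very large reference texts. One possible shape for that, sketched under the assumption that the text comes from a file-like object with a '.read(size)' method: read fixed-size chunks, normalize each one, and carry the last character across chunk boundaries so that no bigram is lost and no doubled word marker is created. 'count_bigrams_streaming' and 'chunk_size' are hypothetical names, and the normalization uses a regex assumed to follow the same rules as 'normalize_text'.

```python
import io
import re
from collections import Counter


def count_bigrams_streaming(stream, chunk_size=1 << 16):
    """Count bigrams of the normalized text read from 'stream' chunk by chunk.

    Hypothetical helper: 'stream' is any object with a .read(size) method
    returning str, e.g. a file opened in text mode.
    """
    counts = Counter()
    previous = ''  # last normalized character of the previous chunk
    while True:
        chunk = stream.read(chunk_size)
        if not chunk:
            break
        # Lowercase, and turn each run of non a-z characters into one space
        text = re.sub(r'[^a-z]+', ' ', chunk.lower())
        if previous == ' ' and text.startswith(' '):
            # This run of separators already started in the previous chunk
            text = text[1:]
        text = previous + text
        counts.update(zip(text, text[1:]))  # includes the pair spanning the boundary
        previous = text[-1]
    return counts


# Example with an in-memory stream; a file opened with open(...) works the same way
example = count_bigrams_streaming(io.StringIO("Hello, World!"), chunk_size=4)
print(example[('l', 'o')])  # 1
```

Dividing each row of these counts by its total, as 'construct_transitions' does, then gives transition probabilities without holding the whole book in memory at once.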
In [4]:
wnp_transitions = construct_transitions(war_and_peace)
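To see what "row-normalized" means in practice, here is a small sketch on a toy string instead of War and Peace. 'row_normalized_bigrams' is a hypothetical helper that mirrors 'construct_transitions', but returns a plain dict containing only the bigrams that actually occur.

```python
from collections import Counter


def row_normalized_bigrams(text):
    """Estimate P(next char | current char) from the bigrams of 'text'.

    Hypothetical helper mirroring 'construct_transitions'.
    """
    counts = Counter(zip(text, text[1:]))
    # Total number of bigrams starting with each character (the row totals)
    row_totals = Counter()
    for (first, _), n in counts.items():
        row_totals[first] += n
    return {(first, second): n / row_totals[first]
            for (first, second), n in counts.items()}


toy = row_normalized_bigrams("the cat ate the hat")
print(toy[('t', 'h')])  # 0.5: 't' is followed by 'h' in 2 of its 4 bigrams
print(toy[('e', ' ')])  # 1.0: 'e' is always followed by the word marker
```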

Next we define some tools for working with encryption/decryption keys.

In [5]:
import random
from contextlib import contextmanager


@contextmanager
def set_seed(seed=None):
    """A context manager that sets/resets the Python RNG seed on entry and exit.
    
    If the provided seed is 'None', then this context manager does nothing.
    """
    if seed is None:
        yield
        return
    rng_state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        # restore the previous RNG state even if the body raises
        random.setstate(rng_state)
In [6]:
from string import ascii_lowercase
from random import shuffle


def random_key(seed=None):
    """Return a random map *from* ciphertext symbols *to* cleartext symbols.
    
    Parameters
    ----------
    seed : int (optional)
        If provided, the Python random generator will be seeded with the provided
        value before generating the key, and restored to its previous state afterwards.
        This is useful for producing the same key twice.
    """
    with set_seed(seed):
        # 'shuffle' only operates in-place on lists
        shuffled = list(ALPHABET)
        shuffle(shuffled)

    return dict(zip(ALPHABET, shuffled))
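An alternative to saving and restoring the global generator state, as 'set_seed' does, is to give the key generator its own 'random.Random' instance. A sketch of that design choice; 'random_key_local' is a hypothetical name, not part of the notebook.

```python
import random
from string import ascii_lowercase


def random_key_local(seed=None, alphabet=ascii_lowercase):
    """Return a random map from ciphertext symbols to cleartext symbols.

    Hypothetical variant of 'random_key' that draws from its own
    random.Random instance, so the global 'random' state is never touched.
    """
    rng = random.Random(seed)  # seed=None falls back to OS randomness (or the clock)
    shuffled = list(alphabet)
    rng.shuffle(shuffled)
    return dict(zip(alphabet, shuffled))


key_a = random_key_local(seed=42)
key_b = random_key_local(seed=42)
print(key_a == key_b)  # True: same seed, same key
```

Because no global state is modified, there is nothing to restore afterwards, and the function stays safe to call from code that relies on the module-level generator.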
In [7]:
def decrypt(ciphertext, key):
    """Decrypt a ciphertext using a substitution cipher with the provided key.
    
    Parameters
    ----------
    ciphertext : str
        The text to decrypt
    key : dict : str → str
        A map *from* ciphertext symbols *to* cleartext symbols.
        Any characters that appear in 'ciphertext' but do not appear in 'key'
        remain unchanged in the cleartext.
    """
    # XXX: If we're going to be calling this many times, we should
    #      consider making the output of 'maketrans' the canonical key format
    return ciphertext.translate(str.maketrans(key))


def encrypt(cleartext, key):
    """Encrypt a cleartext using a substitution cipher with the provided key.
    
    Parameters
    ----------
    cleartext : str
        The text to encrypt
    key : dict : str → str
        A map *from* ciphertext symbols *to* cleartext symbols (the same format
        that 'decrypt' takes). Any characters that appear in 'cleartext' but not
        among the values of 'key' remain unchanged in the ciphertext.
    """
    # Encryption is decryption with the key inverted
    key = {v: k for k, v in key.items()}
    return decrypt(cleartext, key)
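The XXX comment in 'decrypt' suggests making the output of 'str.maketrans' the canonical key format, so that the translation table is built once rather than on every call. A sketch of that idea, with a hypothetical 'key_to_tables' helper and a Caesar shift as the example key (both chosen here for illustration):

```python
from string import ascii_lowercase


def key_to_tables(key):
    """Build (decryption table, encryption table) for str.translate from a key.

    Hypothetical helper; 'key' maps ciphertext symbols to cleartext symbols,
    as in 'decrypt' above.
    """
    decryption_table = str.maketrans(key)
    # Encryption uses the inverse map, cleartext symbol -> ciphertext symbol
    encryption_table = str.maketrans({v: k for k, v in key.items()})
    return decryption_table, encryption_table


# Example key: a Caesar shift, where ciphertext 'd' decrypts to 'a', 'e' to 'b', ...
shift = 3
caesar_key = {c: ascii_lowercase[(i - shift) % 26] for i, c in enumerate(ascii_lowercase)}
decryption_table, encryption_table = key_to_tables(caesar_key)

secret = "hello world".translate(encryption_table)
print(secret)                              # khoor zruog
print(secret.translate(decryption_table))  # hello world
```

Characters missing from the table (here the space) pass through 'str.translate' unchanged, which matches the behaviour documented for 'decrypt'.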

And some utilities for measuring the "distance" between two keys.

In [8]:
def similarity(seq1, seq2):
    l = min(len(seq1), len(seq2))
    return sum(c1 == c2 for c1, c2 in zip(seq1, seq2)) / l


def distance(ciphertext, key1, key2):
    """Return the distance between 'key1' and 'key2'
    
    The distance is defined as the proportion of characters that differ between
    the cleartexts obtained by decrypting 'ciphertext' with 'key1' and with 'key2'.
    """
    cleartext1 = decrypt(ciphertext, key1)
    cleartext2 = decrypt(ciphertext, key2)
    return 1 - similarity(cleartext1, cleartext2)
    

## From https://codereview.stackexchange.com/questions/172060/finding-the-minimum-number-of-swaps-to-sort-a-list
def cycle_decomposition(permutation):
    """Generate cycles in the cyclic decomposition of a permutation.

        >>> list(cycle_decomposition([7, 2, 9, 5, 0, 3, 6, 8, 1, 4]))
        [[0, 7, 8, 1, 2, 9, 4], [3, 5], [6]]

    """
    unvisited = set(permutation)
    while unvisited:
        j = i = unvisited.pop()
        cycle = [i]
        while True:
            j = permutation[j]
            if j == i:
                break
            cycle.append(j)
            unvisited.remove(j)
        yield cycle

        
def minimum_swaps(seq):
    """Return minimum swaps needed to sort the sequence.

        >>> minimum_swaps([])
        0
        >>> minimum_swaps([2, 1])
        1
        >>> minimum_swaps([4, 8, 1, 5, 9, 3, 6, 0, 7, 2])
        7

    """
    permutation = sorted(range(len(seq)), key=seq.__getitem__)
    return sum(len(cycle) - 1 for cycle in cycle_decomposition(permutation))

from math import log, inf, exp


def swapped(key):
    """Return a copy of 'key' with the images of two distinct symbols exchanged."""
    a, b = random.sample(ALPHABET, 2)
    new = key.copy()
    new[a], new[b] = new[b], new[a]
    return new


def transition_probability(proposal_density, key_density):
    """Metropolis acceptance probability: min(1, proposal_density / key_density)."""
    if key_density == 0:
        return 1
    else:
        return min(proposal_density / key_density, 1)


def metropolis(ciphertext, transitions, start_key=None):
    """Log-space variant of the Metropolis sampler defined in the In [9] cell below.

    Summing log-probabilities avoids the floating-point underflow that a long
    product of small probabilities can run into.
    """
    ciphertext = normalize_text(ciphertext)

    # Equation 2.4, in logarithmic form
    def log_pl(key):
        maybe_cleartext = decrypt(ciphertext, key)
        return sum(log(transitions[a, b]) if transitions[a, b] != 0 else -inf
                   for a, b in pairs(maybe_cleartext))

    key = start_key or random_key()
    log_pl_key = log_pl(key)
    yield key

    while True:
        proposal = swapped(key)
        log_pl_proposal = log_pl(proposal)
        if log_pl_proposal > log_pl_key or log_pl_key == -inf:
            key, log_pl_key = proposal, log_pl_proposal
        elif random.uniform(0, 1) < exp(log_pl_proposal - log_pl_key):
            key, log_pl_key = proposal, log_pl_proposal
        yield key
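The 'distance' function in the In [8] cell compares decrypted texts, so two keys that differ only on letters absent from the ciphertext are at distance 0. The 'minimum_swaps' helper in the same cell hints at a key-space alternative: count the smallest number of pairwise swaps (the same move that 'swapped' proposes) needed to turn one key into the other. A sketch of that idea, with 'key_swap_distance' as a hypothetical name; it relies on the standard fact that a permutation of n elements with k cycles is a product of n - k transpositions and no fewer.

```python
from string import ascii_lowercase


def key_swap_distance(key1, key2, alphabet=ascii_lowercase):
    """Minimum number of two-entry swaps that turn 'key1' into 'key2' (hypothetical helper)."""
    # Position (in alphabet order) at which each cleartext symbol appears in key2
    position_in_key2 = {key2[c]: i for i, c in enumerate(alphabet)}
    # permutation[i] is the position that key1's image of alphabet[i] must move to
    permutation = [position_in_key2[key1[c]] for c in alphabet]
    # Count the cycles of the permutation; n elements in k cycles need n - k swaps
    seen = set()
    cycles = 0
    for start in range(len(permutation)):
        if start in seen:
            continue
        cycles += 1
        j = start
        while j not in seen:
            seen.add(j)
            j = permutation[j]
    return len(permutation) - cycles


identity = dict(zip(ascii_lowercase, ascii_lowercase))
print(key_swap_distance(identity, dict(identity, a='b', b='a')))         # 1
print(key_swap_distance(identity, dict(identity, a='b', b='c', c='a')))  # 2
```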

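Finally we define the Metropolis algorithm itself. At each step it proposes a new key by swapping two entries of the current key, and accepts the proposal with probability min(1, pl(proposal) / pl(current key)), where the plausibility pl of a key is the product of the bigram probabilities of the text it decrypts to.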

In [9]:
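def swapped(key):
    """Return a copy of 'key' with the images of two distinct symbols exchanged."""
    # 'random.sample' draws without replacement, so the two symbols always differ
    a, b = random.sample(ALPHABET, 2)
    new = key.copy()
    new[a], new[b] = new[b], new[a]
    return new


def transition_probability(proposal_density, key_density):
    """Metropolis acceptance probability: min(1, proposal_density / key_density)."""
    if key_density == 0:
        return 1
    else:
        return min(proposal_density / key_density, 1)


def metropolis(ciphertext, transitions, start_key=None):
    """Yield an endless Markov chain of decryption keys for 'ciphertext'."""
    ciphertext = normalize_text(ciphertext)

    # Equation 2.4: the plausibility of a key is the product of the bigram
    # probabilities of the text it decrypts to.
    # XXX: construct this using logarithms to avoid excessive rounding error
    def pl(key):
        maybe_cleartext = decrypt(ciphertext, key)
        return prod(transitions[a, b] for a, b in pairs(maybe_cleartext))

    key = start_key or random_key()
    pl_key = pl(key)  # cache the current key's score instead of recomputing it
    yield key

    while True:
        proposal = swapped(key)
        pl_proposal = pl(proposal)
        # Always accept an improvement (or any move away from an impossible key);
        # otherwise accept with probability pl_proposal / pl_key
        if pl_proposal > pl_key or pl_key == 0:
            key, pl_key = proposal, pl_proposal
        elif random.uniform(0, 1) < pl_proposal / pl_key:
            key, pl_key = proposal, pl_proposal
        yield key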
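The accept/reject rule in 'metropolis' is the standard Metropolis rule for a symmetric proposal: swapping two entries of a key is its own inverse, so proposing f' from f is exactly as likely as proposing f from f'. A self-contained sketch of the same rule on a four-state toy problem, where the target distribution is known and the visit frequencies can be checked against it; 'metropolis_ring' is a hypothetical name.

```python
import random


def metropolis_ring(weights, n_steps, seed=0):
    """Visit frequencies of a Metropolis chain on states 0..len(weights)-1.

    Hypothetical toy: the proposal moves one step left or right around a ring,
    which is symmetric, so accepting with probability
    min(1, weights[proposal] / weights[state]) makes the chain sample states
    with probability proportional to 'weights'.
    """
    rng = random.Random(seed)
    n = len(weights)
    state = 0
    visits = [0] * n
    for _ in range(n_steps):
        proposal = (state + rng.choice((-1, 1))) % n
        if rng.random() < min(1.0, weights[proposal] / weights[state]):
            state = proposal
        visits[state] += 1  # a rejected move counts the current state again
    return [count / n_steps for count in visits]


frequencies = metropolis_ring([1, 2, 3, 4], n_steps=200_000)
print(frequencies)  # target: [0.1, 0.2, 0.3, 0.4]
```

Counting a rejected step as another visit to the current state is what makes the frequencies come out right; the 'yield key' after a rejection in 'metropolis' plays the same role.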



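And run the algorithm on some example text to see if it works! As a first sanity check, the ciphertext is simply the cleartext (the call to 'encrypt' is commented out) and the chain starts from the identity key, which is therefore the true key; a well-behaved chain should stay close to it.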

In [10]:
cleartext = normalize_text("""
Enter by the narrow gate, for wide is the gate and broad the road that leads to destruction
""")

ciphertext = cleartext  #encrypt(cleartext, random_key())

keys = metropolis(ciphertext, wnp_transitions, start_key=dict(zip(ALPHABET, ALPHABET)))

for i, key in enumerate(take(50000, keys)):
    if i % 2000 == 0:
        print(i, ':', decrypt(ciphertext, key))
0 :  enter by the narrow gate for wide is the gate and broad the road that leads to destruction 
2000 :  enter by the narrok wate mor kide is the wate and broad the road that veads to destrcution 
4000 :  enter gu the narrol bate for lide is the bate and groad the road that ceads to destrmption 
6000 :  enter py the narrof wate cor fide is the wate and proad the road that leads to destrmution 
8000 :  enter bl the narrof wate cor fide is the wate and broad the road that veads to destruption 
10000 :  enter pl the narrof wate bor fide is the wate and proad the road that keads to destrqution 
12000 :  enter fl the narrom wate gor mide is the wate and froad the road that ceads to destruption 
14000 :  enter by the narrof wate mor fide is the wate and broad the road that veads to destrqution 
16000 :  enter py the narrof wate mor fide is the wate and proad the road that beads to destrkution 
18000 :  enter py the narrov wate lor vide is the wate and proad the road that beads to destrgution 
20000 :  enter by the narrof wate mor fide is the wate and broad the road that veads to destrpution 
22000 :  enter by the narrof wate cor fide is the wate and broad the road that jeads to destrmution 
24000 :  enter bl the narrof wate gor fide is the wate and broad the road that peads to destrmytion 
26000 :  enter by the narrof wate cor fide ig the wate and broad the road that keadg to degtrmstion 
28000 :  enter by the narrol wate cor lide is the wate and broad the road that peads to destrqution 
30000 :  enter by the narrol wate for lide is the wate and broad the road that keads to destrmution 
32000 :  enter by the narrof pate cor fide is the pate and broad the road that keads to destrmution 
34000 :  enter by the narrof wate cor fide is the wate and broad the road that meads to destrlution 
36000 :  enter py the narrof wate cor fide is the wate and proad the road that geads to destrmution 
38000 :  enter py the narrom wate lor mide is the wate and proad the road that beads to destrcution 
40000 :  enter by the narrof wate vor fide is the wate and broad the road that geads to destrmption 
42000 :  enter by the narrom gate cor mide is the gate and broad the road that weads to destrkution 
44000 :  entel by the nallof wate jol frde rs the wate and bload the load that peads to destlkitron 
46000 :  enter by the narrow mate lor wide is the mate and broad the road that veads to destrgution 
48000 :  enter by the narrof wate cor fide is the wate and broad the road that leads to destruption 
CPU times: user 4.31 s, sys: 15.5 ms, total: 4.33 s
Wall time: 4.32 s
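Because the chain keeps wandering (the printout above starts at the correct text and then drifts away from it), it may be worth remembering the best-scoring key seen so far instead of only reading off the current one. A generic sketch, assuming the score of a key is available as a standalone function; 'track_best' and its 'score' argument are hypothetical, and in this setting 'score' could be the (log-)plausibility of the decrypted text.

```python
from itertools import islice


def track_best(chain, score, n_steps):
    """Return (best_state, best_score) among the first 'n_steps' states of 'chain'.

    Hypothetical helper; 'chain' is any iterator of states (e.g. the keys
    yielded by a Metropolis generator) and 'score' maps a state to a number.
    """
    best_state, best_score = None, float('-inf')
    for state in islice(chain, n_steps):
        current_score = score(state)
        if current_score > best_score:
            best_state, best_score = state, current_score
    return best_state, best_score


# Toy usage: states are numbers and the score prefers values close to 4
print(track_best(iter([3, 1, 4, 1, 5, 9, 2, 6]), lambda x: -abs(x - 4), n_steps=8))  # (4, 0)
```

Since 'swapped' returns a new dict rather than modifying the current key, holding on to a yielded key this way is safe.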
In [30]:
cleartext = normalize_text("""
Enter by the narrow gate, for wide is the gate and broad the road that leads to destruction.
""")

solution = dict(zip(ALPHABET, ALPHABET))  #random_key()

ciphertext = cleartext  #encrypt(cleartext, solution)

keys = metropolis(ciphertext, wnp_transitions, start_key=dict(zip(ALPHABET, ALPHABET)))

distances = [distance(ciphertext, k, solution) for k in take(20000, keys)]
In [31]:
import matplotlib.pyplot as plt

plt.plot(distances)
plt.xlabel('Markov chain step')
plt.ylabel('distance from the true key')
Out[31]:
[Figure: distance between each sampled key and the true key over 20000 steps of the Markov chain]

Closing remarks

The Markov chain seems to get stuck at some minimum distance from the true key. It's not 100% clear to me why this is the case; if anyone has any insights, drop me an email!