Rubik's cube simulation in 15 lines of numpy

The goal of this project was to implement the entire logic of a Rubik's cube using only basic NumPy functionality: reshaping, multi-indexing, and `np.indices`.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon

The core logic

This creates a permutation inside a tensor of shape $5 \times N \times N \times N \times 3 \times 2$.

The idea is that the five parameters (ix, iy, iz, d, a) contain exactly the information needed to locate one face of a tiny cube: it sits at cell (ix, iy, iz), is oriented along direction d, and a indicates which way it points along that direction (up or down, right or left, ...).

Then, we can define a huge array with one element for each such position. Here comes the secret sauce: a configuration of the Rubik's cube is just a particular permutation of this huge array!

def arr(shape, *x):
    """Build an ``int`` array from the positional values ``x`` and reshape it to ``shape``.

    The file also passes single-digit characters (e.g. ``*"0102"``); NumPy parses
    each of them as an integer.
    """
    return np.array(x, dtype=int).reshape(shape)


def cho(*s):
    """Call ``np.random.choice(*s)`` without replacement.

    ``cho(a)`` picks one element of ``a``; ``cho(a, k)`` picks ``k`` distinct ones.
    """
    return np.random.choice(*s, replace=False)


# The three axis indices 0, 1, 2.
i3 = arr((3,), 0, 1, 2)


def solved(n):
    """Return the identity configuration of an ``n``-cube.

    The result has shape ``(5, n, n, n, 3, 2)`` and
    ``solved(n)[:, ix, iy, iz, d, a] == (ix, iy, iz, d, a)``: every slot points to itself.
    """
    return np.indices((n, n, n, 3, 2))


def move(n, i, d, d1, d2):
    """Return the permutation performing a quarter turn of one layer of an ``n``-cube.

    The turned layer is the slice where coordinate ``d`` equals ``i``; inside it the
    coordinates ``d1`` and ``d2`` are rotated. Apply the result to a configuration
    ``cube`` with ``cube[:, *move(...)]``: each slot of the new cube is read from
    the slot the permutation points to.

    Parameters
    ----------
    n : cube size.
    i : index of the layer along axis ``d`` (``0 <= i < n``).
    d, d1, d2 : axis indices. ``d`` is the rotation axis; ``d1``/``d2`` span the
        layer. Callers in this file pass three distinct axes; swapping ``d1`` and
        ``d2`` turns the layer the other way.

    Returns
    -------
    ndarray of shape ``(5, n, n, n, 3, 2)``, a permutation of ``solved(n)``.
    """
    s = solved(n)
    # Write into a copy: updating ``s`` in place would overwrite the d1 coordinates
    # before they are read for the d2 update, and the map would stop being a bijection.
    result = s.copy()
    idx = [slice(n)] * 3
    idx[d] = i
    # Position: the source of (p_d1, p_d2) is (p_d2, n-1-p_d1), a 90-degree rotation.
    result[(d1, *idx)] = s[(d2, *idx)]
    result[(d2, *idx)] = n - 1 - s[(d1, *idx)]
    # Orientation: faces along d1 come from faces along d2, and vice versa ...
    result[(-2, *idx, [d1, d2])] = s[(-2, *idx, [d2, d1])]
    # ... and the d2 -> d1 transfer crosses the flipped coordinate, so its side flips.
    result[(-1, *idx, d1)] = 1 - s[(-1, *idx, d2)]
    return result

Display

Next, we need some code to display our cube.

# Face colours as RGB in [0, 1], indexed COLORS[d, a]:
#   d=0 -> blue / dark green, d=1 -> yellow / white, d=2 -> red / orange.
COLORS = np.array(
    [
        [[0, 0, 1], [0, 0.5, 0]],
        [[1, 1, 0], [1, 1, 1]],
        [[1, 0, 0], [1, 0.5, 0]],
    ],
    dtype=float,
)
# Unit square in the z = 0 plane: four (x, y, z) corners in drawing order.
SQUARE = np.array([[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], dtype=int)

def display(ax, cube):
    """Draw ``cube`` on the Matplotlib axes ``ax`` as an isometric-style picture.

    ``ax`` is cleared first. For each axis, the stickers on the ``cell[axis] == 0``
    side are drawn as quadrilaterals. Each one is coloured by
    ``COLORS[direction, sign]`` of the sticker currently stored in that slot.

    Returns the list of Polygon patches added to ``ax``.
    """
    size = cube.shape[1]
    ax.clear()
    ax.set_axis_off()
    ax.set_xlim(-size * 2, size * 2)
    ax.set_ylim(-size * 3, size * 3)

    drawn = []
    for axis in i3:
        # Move the template square into the plane where coordinate `axis` is 0.
        corners = SQUARE.copy()
        corners[:, [2, axis]] = corners[:, [axis, 2]]
        for cell in np.ndindex(size, size, size):
            if cell[axis] != 0:
                continue
            xs, ys, zs = (corners + cell).T
            # Project the 3-D corners onto the page.
            outline = np.array([xs - zs, xs - 2 * ys + zs], dtype=int)
            direction, sign = cube[(slice(-2, None), *cell, axis, 0)]
            face = Polygon(outline.T, edgecolor="0", facecolor=COLORS[direction, sign])
            drawn.append(ax.add_patch(face))
    return drawn
# Demo: starting from a solved 2x2x2 cube, turn the layer at index 0 around each
# axis and draw each result. The two rows use opposite orderings of (d1, d2).
r0 = solved(2)

fig, axes = plt.subplots(2, 3, figsize=(9,6))

for k, row in zip((-1, 1), axes):
    for axis, ax in enumerate(row):
        turn = move(2, 0, axis, (axis + k) % 3, (axis - k) % 3)
        r = r0[(slice(None), *turn)]
        display(ax, r)
plt.show()
Generated plot
import matplotlib.animation
plt.rcParams["animation.html"] = "jshtml"
plt.ioff()  # interactive mode off (presumably so the setup figure is not shown on its own)
r = solved(4)
fig, ax = plt.subplots()
display(ax, r)


def update(t):
    """Apply one random quarter turn to the global cube ``r`` and redraw it.

    ``t`` is the frame number passed by FuncAnimation; it is not used.
    Returns the patches drawn by ``display``.
    """
    global r
    n = r.shape[1]
    # Choose among all n layers; ``i3`` only covers 0..2, which would leave the
    # outer layers of larger cubes fixed forever.
    layer = np.random.randint(n)
    # A shuffled (0, 1, 2) gives the rotation axis d and the in-layer axes d1, d2.
    m = move(n, layer, *cho(i3, 3))
    r = r[(slice(None), *m)]
    return display(ax, r)


# Keep a reference so the animation is not garbage-collected. As the last
# expression of the cell, it is also what a notebook renders.
anim = matplotlib.animation.FuncAnimation(fig, update, frames=20, interval=500)
anim