Shape typing in Python
While I was looking the other way,
Python got advanced static types!
Here’s matrix multiplication,
with types describing its input and output shapes:
def mat_mul[
    N, K, M
](
    m1: Mat[N, M],
    m2: Mat[M, K],
) -> Mat[N, K]:
    return m1 @ m2
There’s a lot going on here!
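The snippet doesn't show how `Mat` is declared. Here is a minimal sketch of one way it could be, assuming `Mat` is an ordinary generic class over two dimension type variables. It spells the type parameters with `typing.TypeVar`, which is roughly what the `def mat_mul[N, K, M]` syntax (new in Python 3.12) does for you, so it also runs on older versions. `Mat`'s `rows` and `shape` members, the runtime check in `__matmul__`, and the marker classes `Rows`, `Inner` and `Cols` are all hypothetical, not the author's definitions.

```python
from __future__ import annotations

from typing import Generic, TypeVar

# Pre-3.12 spelling of `def mat_mul[N, K, M](...)`: the type parameters
# become module-level TypeVars.
N = TypeVar("N")
M = TypeVar("M")
K = TypeVar("K")


class Mat(Generic[N, M]):
    """Hypothetical N-by-M matrix. N and M exist only for the type checker;
    at runtime the data is a plain list of rows."""

    def __init__(self, rows: list[list[float]]) -> None:
        self.rows = rows

    @property
    def shape(self) -> tuple[int, int]:
        return (len(self.rows), len(self.rows[0]) if self.rows else 0)

    def __matmul__(self, other: Mat[M, K]) -> Mat[N, K]:
        # The annotations promise matching inner dimensions, but nothing
        # enforces that at runtime, so check explicitly.
        n, m = self.shape
        m_other, k = other.shape
        if m != m_other:
            raise ValueError(f"inner dimensions differ: {self.shape} @ {other.shape}")
        return Mat(
            [
                [sum(self.rows[i][j] * other.rows[j][c] for j in range(m)) for c in range(k)]
                for i in range(n)
            ]
        )


def mat_mul(m1: Mat[N, M], m2: Mat[M, K]) -> Mat[N, K]:
    return m1 @ m2


# Hypothetical marker classes standing in for concrete dimensions, so the
# checker has distinct types to compare.
class Rows: ...
class Inner: ...
class Cols: ...


a: Mat[Rows, Inner] = Mat([[1.0, 2.0], [3.0, 4.0]])  # 2x2
b: Mat[Inner, Cols] = Mat([[5.0], [6.0]])  # 2x1
c: Mat[Rows, Cols] = mat_mul(a, b)  # inner dimensions agree: accepted

# mat_mul(b, a) should be rejected by a checker such as mypy or pyright:
# b's second dimension is Cols, but a's first dimension is Rows.
# Run anyway, it hits the ValueError in __matmul__ (inner sizes 1 vs 2).
```

The type variables are labels only: a checker compares them, but they carry no sizes, so the explicit check in `__matmul__` is what actually fires when the program runs.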
In traditional Python, we’d write:
def mat_mul(m1, m2):
    return m1 @ m2
If we then passed in the wrong shapes,
we'd only find out at runtime,
with an error like this:
>>> m1 @ m2
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: ...
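A minimal sketch of that failure mode, assuming matrices stored as plain lists of rows rather than the array type behind the traceback. The hypothetical `buggy_pipeline` multiplies its operands in the wrong order, and nothing complains when it is defined; the error only appears once the call executes.

```python
def mat_mul(m1, m2):
    """Untyped product of two matrices stored as lists of rows."""
    n, m = len(m1), len(m1[0])
    m_other, k = len(m2), len(m2[0])
    if m != m_other:
        # Only detectable here, when the call actually executes.
        raise ValueError(f"inner dimensions differ: {n}x{m} @ {m_other}x{k}")
    return [
        [sum(m1[i][j] * m2[j][c] for j in range(m)) for c in range(k)]
        for i in range(n)
    ]


def buggy_pipeline():
    # Hypothetical bug: operands in the wrong order, (2x1) @ (2x2).
    # Defining this function raises nothing.
    x = [[5.0], [6.0]]
    w = [[1.0, 2.0], [3.0, 4.0]]
    return mat_mul(x, w)


try:
    buggy_pipeline()  # the shape error surfaces only now
except ValueError as err:
    print("ValueError:", err)  # prints: ValueError: inner dimensions differ: 2x1 @ 2x2
```

A static checker reading shape-typed signatures could flag `buggy_pipeline` without ever calling it; here the mistake stays hidden until that code path runs.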
Read more at jameshfisher.com