tensor puzzles

My solutions to the problems from srush/Tensor-Puzzles.

Index

  1. Tensor Puzzle Solutions
    1. Puzzle 1 - ones
    2. Puzzle 2 - sum
    3. Puzzle 3 - outer
    4. Puzzle 4 - diag
    5. Puzzle 5 - eye
    6. Puzzle 6 - triu
    7. Puzzle 7 - cumsum
    8. Puzzle 8 - diff
    9. Puzzle 9 - vstack
    10. Puzzle 10 - roll
    11. Puzzle 11 - flip
    12. Puzzle 12 - compress
    13. Puzzle 13 - pad_to
    14. Puzzle 14 - sequence_mask
    15. Puzzle 15 - bincount
    16. Puzzle 16 - scatter_add
    17. Puzzle 17 - flatten
    18. Puzzle 18 - linspace
    19. Puzzle 19 - heaviside
    20. Puzzle 20 - repeat (1d)
    21. Puzzle 21 - bucketize

Tensor Puzzle Solutions

The rules are simple:

  1. You can use tensor broadcasting.

  2. Each puzzle needs to be solved in 1 line (<80 columns) of code.

  3. You are allowed @, arithmetic, comparison, shape, any indexing (e.g. a[:j], a[:, None], a[arange(10)]), and previous puzzle functions.

  4. You are not allowed anything else. No view, sum, take, squeeze, tensor.

  5. The following functions are implemented for you:

    • arange to replace a for-loop
    python
    def arange(i: int):
        "Use this function to replace a for-loop."
        return torch.tensor(range(i))
    
    • where to replace an if-statement
    python
    def where(q, a, b):
        "Use this function to replace an if-statement."
        return (q * a) + (~q) * b
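
As a minimal sketch of how these two helpers combine with broadcasting, the snippet below restates them with `import torch` and builds a small comparison grid. The names `rows`, `cols`, and `table` are illustrative and not part of the puzzles. It also shows one pitfall: `where` expects a boolean tensor for `q`, because `~` on a plain Python `bool` is integer bitwise negation (`~True == -2`).

```python
import torch

def arange(i: int):
    "Use this function to replace a for-loop."
    return torch.tensor(range(i))

def where(q, a, b):
    "Use this function to replace an if-statement."
    return (q * a) + (~q) * b

# Broadcasting: a (3, 1) column against a (1, 4) row yields a (3, 4) grid.
rows = arange(3)[:, None]
cols = arange(4)[None, :]
table = where(rows <= cols, 1, 0)
print(table)

# Pitfall: with a Python bool, ~True is -2, so the "else" value leaks in.
print(where(True, 5, 1))                # 5 + (-2) * 1 = 3
print(where(torch.tensor(True), 5, 1))  # a tensor mask gives tensor(5)
```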
    

Puzzle 1 - ones

Compute ones - the vector of all ones.

python
def ones(i: int) -> TT["i"]:
    return where(arange(i) > -1, 1, 0)

Puzzle 2 - sum

Compute sum - the sum of a vector.

python
def sum(a: TT["i"]) -> TT[1]:
    return ones(a.shape[0]) @ a[:, None]

Puzzle 3 - outer

Compute outer - the outer product of two vectors.

python
def outer(a: TT["i"],b: TT["j"]) -> TT["i", "j"]:
    return a[:, None] @ b[None, :]

Puzzle 4 - diag

Compute diag - the diagonal vector of a square matrix.

python
def diag(a: TT["i", "i"]) -> TT["i"]:
    return a[arange(a.shape[0]), arange(a.shape[0])]

Puzzle 5 - eye

Compute eye - the identity matrix.

python
def eye(j: int) -> TT["j", "j"]:
    return where(arange(j)[:, None] == arange(j)[None, :], 1, 0)

Puzzle 6 - triu

Compute triu - the upper triangular matrix.

python
def triu(j: int) -> TT["j", "j"]:
    return where(arange(j)[:, None] <= arange(j)[None, :], 1, 0)

Puzzle 7 - cumsum

Compute cumsum - the cumulative sum.

python
def cumsum(a: TT["i"]) -> TT["i"]:
    return (a[None, :] @ triu(a.shape[0]))[0]
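
To see why the matrix product gives a running total: column k of `triu(n)` has ones in rows 0 through k, so multiplying `a` by it sums `a[0]` through `a[k]` into position k. The sketch below restates the helpers with plain `torch` (dropping the `TT` annotations, which is an assumption made to keep it self-contained) and compares against `torch.cumsum`, used here only as a reference.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def triu(j: int):
    # 1 where row index <= column index, 0 elsewhere
    return where(arange(j)[:, None] <= arange(j)[None, :], 1, 0)

def cumsum(a):
    # (1, n) @ (n, n) -> (1, n); take the single row
    return (a[None, :] @ triu(a.shape[0]))[0]

a = torch.tensor([3, 1, 4, 1, 5])
print(triu(5))
print(cumsum(a))  # tensor([ 3,  4,  8,  9, 14])
```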

Puzzle 8 - diff

Compute diff - the running difference.

python
def diff(a: TT["i"], i: int) -> TT["i"]:
    return where(arange(a.shape[0]) == 0, a[0], a - a[arange(i) - 1])

Puzzle 9 - vstack

Compute vstack - the matrix of two vectors.

python
def vstack(a: TT["i"], b: TT["i"]) -> TT[2, "i"]:
    return where(arange(2)[:, None] == 0, a[None, :], b[None, :])

Puzzle 10 - roll

Compute roll - the vector shifted 1 circular position.

python
def roll(a: TT["i"], i: int) -> TT["i"]:
    return a[(arange(i) + 1) % a.shape[0]]

Puzzle 11 - flip

Compute flip - the reversed vector.

python
def flip(a: TT["i"], i: int) -> TT["i"]:
    return a[i - arange(i) - 1]

Puzzle 12 - compress

Compute compress - keep only masked entries (left-aligned).

python
def compress(g: TT["i", bool], v: TT["i"], i:int) -> TT["i"]:
    return v @ where(g[:, None], arange(i) == (cumsum(1*g) - 1)[:, None], 0)
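
The solution routes each kept entry to its destination through a one-hot matrix: `cumsum(1*g) - 1` is the output slot of each kept element, and rows where `g` is false are zeroed out. The sketch below unpacks those steps. The intermediate names `slot` and `route` are illustrative only, and the helpers are restated without `TT` annotations.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def cumsum(a):
    n = a.shape[0]
    upper = where(arange(n)[:, None] <= arange(n)[None, :], 1, 0)
    return (a[None, :] @ upper)[0]

g = torch.tensor([False, True, False, True, True])
v = torch.tensor([10, 20, 30, 40, 50])

# Output slot for each element: kept entries get 0, 1, 2, ...
slot = cumsum(1 * g) - 1
# Row r is one-hot at column slot[r] if g[r] is true, all zeros otherwise.
route = where(g[:, None], arange(5) == slot[:, None], 0)
print(slot)       # tensor([-1,  0,  0,  1,  2])
print(v @ route)  # tensor([20, 40, 50,  0,  0])
```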

Puzzle 13 - pad_to

Compute pad_to - eliminate or add 0s to change size of vector.

python
def pad_to(a: TT["i"], i: int, j: int) -> TT["j"]:
    return compress(ones(i) <= j, a * (where(arange(i)<j, 1, 0)), j)

Puzzle 14 - sequence_mask

Compute sequence_mask - pad out to length per batch, zeroing the entries of each row beyond that row's length.

python
def sequence_mask(values: TT["i", "j"], length: TT["i", int]) -> TT["i", "j"]:
    return (arange(values.shape[1])[None, :] < length[:, None]) * values
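
A small worked example may make the broadcasting clearer: the column indices form a `(1, j)` row, the lengths form an `(i, 1)` column, and comparing them gives an `(i, j)` keep-mask. The input values below are made up for illustration.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def sequence_mask(values, length):
    # (1, j) column indices < (i, 1) lengths -> (i, j) boolean keep-mask
    return (arange(values.shape[1])[None, :] < length[:, None]) * values

values = torch.tensor([[1, 2, 3], [4, 5, 6]])
length = torch.tensor([1, 3])
masked = sequence_mask(values, length)
print(masked)  # row 0 keeps 1 entry, row 1 keeps all 3
```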

Puzzle 15 - bincount

Compute bincount - count number of times an entry was seen.

python
def bincount(a: TT["i"], j: int) -> TT["j"]:
    return ones(a.shape[0]) @ where(a[:, None] == arange(j)[None, :], 1, 0)

Puzzle 16 - scatter_add

Compute scatter_add - add together values that link to the same location.

python
def scatter_add(values: TT["i"], link: TT["i"], j: int) -> TT["j"]:
    return values @ where(link[:, None] == arange(j)[None, :], 1, 0)
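
The one-hot comparison `link[:, None] == arange(j)` marks, for each value, the single output slot it belongs to. Summing over the first axis then adds together values that share a slot. The sketch below performs that sum as a vector-matrix product and checks it against `torch.Tensor.scatter_add`, used here only as a reference. The sample data is made up.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def scatter_add(values, link, j: int):
    # (i,) @ (i, j) one-hot matrix -> (j,) per-slot totals
    return values @ where(link[:, None] == arange(j)[None, :], 1, 0)

values = torch.tensor([5, 1, 7, 2, 3])
link = torch.tensor([0, 0, 1, 0, 2])
out = scatter_add(values, link, 3)
reference = torch.zeros(3, dtype=values.dtype).scatter_add(0, link, values)
print(out)        # tensor([8, 7, 3])
print(reference)  # tensor([8, 7, 3])
```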

Puzzle 17 - flatten

Compute flatten - the matrix laid out as a single vector in row-major order.

python
def flatten(a: TT["i", "j"], i: int, j: int) -> TT["i * j"]:
    return a[arange(i * j) // j, arange(i * j) % j]
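
Row-major flattening maps output position k to row `k // j` and column `k % j`, which is the index arithmetic the solution relies on. The sketch below indexes with both at once and checks the result against `reshape`, used only as a reference since it is outside the puzzle rules.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def flatten(a, i: int, j: int):
    # Position k reads a[k // j, k % j]
    return a[arange(i * j) // j, arange(i * j) % j]

a = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(arange(6) // 3)    # row of each output position
print(arange(6) % 3)     # column of each output position
print(flatten(a, 2, 3))  # tensor([1, 2, 3, 4, 5, 6])
```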

Puzzle 18 - linspace

Compute linspace - n evenly spaced values from i to j, inclusive.

python
def linspace(i: TT[1], j: TT[1], n: int) -> TT["n", float]:
    return i + (j - i) * (arange(n) / (n - 1 + (n <= 1)))
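
As a quick check, the sketch below compares a linspace one-liner against `torch.linspace`. It assumes the `n == 1` guard is written with plain arithmetic, `n - 1 + (n <= 1)`, because the `where` helper applies `~` to its condition and a plain Python `bool` would be negated as an integer there. The endpoints are made-up sample values.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def linspace(i, j, n: int):
    # Step fraction k / (n - 1), with the divisor forced to 1 when n == 1
    return i + (j - i) * (arange(n) / (n - 1 + (n <= 1)))

start = torch.tensor([0.0])
stop = torch.tensor([1.0])
points = linspace(start, stop, 5)
single = linspace(start, stop, 1)
print(points)  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
print(single)  # tensor([0.])
```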

Puzzle 19 - heaviside

Compute heaviside - the step function of a (0 where a < 0, 1 where a > 0), taking the value from b where a == 0.

python
def heaviside(a: TT["i"], b: TT["i"]) -> TT["i"]:
    return where(a == 0, b, a>0)

Puzzle 20 - repeat (1d)

Compute repeat - the vector a stacked d times as the rows of a matrix.

python
def repeat(a: TT["i"], d: TT[1]) -> TT["d", "i"]:
    return ones(d)[:, None] @ a[None, :]

Puzzle 21 - bucketize

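Compute bucketize - for each entry of v, the number of boundaries it is greater than or equal to (its bucket index when boundaries is sorted).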

python
def bucketize(v: TT["i"], boundaries: TT["j"]) -> TT["i"]:
    return ones(boundaries.shape[0]) @ where(v >= boundaries[:, None], 1, 0)
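
Counting the boundaries that each value meets or exceeds gives its bucket index. For sorted boundaries this should agree with `torch.bucketize` called with `right=True`, which the sketch below uses only as a reference. The helpers are restated without `TT` annotations and the sample data is made up. The result is a vector of shape `(i,)`.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def ones(i: int):
    return where(arange(i) > -1, 1, 0)

def bucketize(v, boundaries):
    # (j,) @ (j, i) indicator matrix -> (i,) counts of boundaries <= value
    return ones(boundaries.shape[0]) @ where(v >= boundaries[:, None], 1, 0)

v = torch.tensor([1, 4, 5, 9])
boundaries = torch.tensor([2, 4, 6, 8])
buckets = bucketize(v, boundaries)
print(buckets)  # tensor([0, 2, 2, 4])
print(torch.bucketize(v, boundaries, right=True))
```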