Index
- Tensor Puzzle Solutions
- Puzzle 1 - ones
- Puzzle 2 - sum
- Puzzle 3 - outer
- Puzzle 4 - diag
- Puzzle 5 - eye
- Puzzle 6 - triu
- Puzzle 7 - cumsum
- Puzzle 8 - diff
- Puzzle 9 - vstack
- Puzzle 10 - roll
- Puzzle 11 - flip
- Puzzle 12 - compress
- Puzzle 13 - pad_to
- Puzzle 14 - sequence_mask
- Puzzle 15 - bincount
- Puzzle 16 - scatter_add
- Puzzle 17 - flatten
- Puzzle 18 - linspace
- Puzzle 19 - heaviside
- Puzzle 20 - repeat (1d)
- Puzzle 21 - bucketize
Tensor Puzzle Solutions
The rules are simple:
- Can use tensor broadcasting.
- Each puzzle needs to be solved in 1 line (<80 columns) of code.
- You are allowed @, arithmetic, comparison, shape, any indexing (e.g. a[:j], a[:, None], a[arange(10)]), and previous puzzle functions.
- You are not allowed anything else. No view, sum, take, squeeze, tensor.
- The following functions are implemented for you:
  - arange to replace a for-loop
  - where to replace an if-statement

def arange(i: int):
    "Use this function to replace a for-loop."
    return torch.tensor(range(i))

def where(q, a, b):
    "Use this function to replace an if-statement."
    return (q * a) + (~q) * b
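To make the interplay of the two helpers and broadcasting concrete, here is a minimal sketch. The names rows, cols, mask and picked are illustrative only and do not appear in the puzzles.

```python
import torch

def arange(i: int):
    # Helper from the rules: stands in for a for-loop.
    return torch.tensor(range(i))

def where(q, a, b):
    # Helper from the rules: stands in for an if-statement.
    return (q * a) + (~q) * b

rows = arange(3)[:, None]    # shape (3, 1)
cols = arange(4)[None, :]    # shape (1, 4)
mask = rows <= cols          # broadcasts to a boolean (3, 4) matrix
picked = where(mask, 7, -1)  # 7 where the mask holds, -1 elsewhere
print(picked)
```

Most of the solutions follow this pattern: build index grids with arange, compare them to get a mask, then select or multiply.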
Puzzle 1 - ones
Compute ones - the vector of all ones.
def ones(i: int) -> TT["i"]:
    return where(arange(i) > -1, 1, 0)
Puzzle 2 - sum
Compute sum - the sum of a vector.
def sum(a: TT) -> TT[1]:
    return ones(a.shape[0]) @ a[:, None]
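The sum is a dot product with the all-ones vector, which a small check makes visible. This is a sketch only: the function is renamed vector_sum here (an assumption, to avoid shadowing Python's built-in sum), and the sample vector is made up.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def ones(i: int):
    return where(arange(i) > -1, 1, 0)

def vector_sum(a):
    # (i,) @ (i, 1) -> (1,): every entry of a is weighted by 1 and added up.
    return ones(a.shape[0]) @ a[:, None]

a = torch.tensor([3, -1, 4, 1])
total = vector_sum(a)
print(total, torch.sum(a))
```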
Puzzle 3 - outer
Compute outer - the outer product of two vectors.
def outer(a: TT["i"], b: TT["j"]) -> TT["i", "j"]:
    return a[:, None] @ b[None, :]
Puzzle 4 - diag
Compute diag - the diagonal vector of a square matrix.
def diag(a: TT["i", "i"]) -> TT["i"]:
    return a[arange(a.shape[0]), arange(a.shape[0])]
Puzzle 5 - eye
Compute eye - the identity matrix.
def eye(j: int) -> TT["j", "j"]:
    return where(arange(j)[:, None] == arange(j)[None, :], 1, 0)
Puzzle 6 - triu
Compute triu - the upper triangular matrix.
def triu(j: int) -> TT["j", "j"]:
    return where(arange(j)[:, None] <= arange(j)[None, :], 1, 0)
Puzzle 7 - cumsum
Compute cumsum - the cumulative sum.
def cumsum(a: TT["i"]) -> TT["i"]:
    return (a[None, :] @ triu(a.shape[0]))[0]
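Column k of the upper triangular matrix has ones in rows 0 through k, so multiplying by it adds up the first k + 1 entries. A minimal sketch with a made-up input vector:

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def triu(j: int):
    return where(arange(j)[:, None] <= arange(j)[None, :], 1, 0)

def cumsum(a):
    # (1, i) @ (i, i) -> (1, i); [0] drops the leading axis again.
    return (a[None, :] @ triu(a.shape[0]))[0]

a = torch.tensor([1, 2, 3, 4])
mask = triu(4)
running = cumsum(a)
print(mask)
print(running)  # tensor([ 1,  3,  6, 10])
```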
Puzzle 8 - diff
Compute diff - the running difference.
def diff(a: TT["i"], i: int) -> TT["i"]:
    return where(arange(a.shape[0]) == 0, a[0], a - a[arange(i) - 1])
Puzzle 9 - vstack
Compute vstack - the matrix of two vectors.
def vstack(a: TT["i"], b: TT["i"]) -> TT[2, "i"]:
    return where(arange(2)[:, None] == 0, a[None, :], b[None, :])
Puzzle 10 - roll
Compute roll - the vector shifted 1 circular position.
def roll(a: TT["i"], i: int) -> TT["i"]:
    return a[(arange(i) + 1) % a.shape[0]]
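The direction of the shift is easy to misread, so here is a small sketch with a made-up vector. Position k reads from position k + 1, so entries move one step toward the front and the first entry wraps around to the end. For this input that matches torch.roll with a shift of -1.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def roll(a, i: int):
    # Index k reads a[(k + 1) % i]; the modulo wraps the last index to 0.
    return a[(arange(i) + 1) % a.shape[0]]

a = torch.tensor([10, 20, 30, 40])
rolled = roll(a, 4)
print(rolled)  # tensor([20, 30, 40, 10])
```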
Puzzle 11 - flip
Compute flip - the reversed vector.
def flip(a: TT["i"], i: int) -> TT["i"]:
    return a[i - arange(i) - 1]
Puzzle 12 - compress
Compute compress - keep only masked entries (left-aligned).
def compress(g: TT["i", bool], v: TT["i"], i: int) -> TT["i"]:
    return v @ where(g[:, None], arange(i) == (cumsum(1*g) - 1)[:, None], 0)
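The solution builds a routing matrix: row r has a single 1 in the column where v[r] should land, and is all zeros when g[r] is False. The target column is the running count of kept entries minus one. A sketch with made-up inputs, where routing is an illustrative name for the intermediate matrix:

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def triu(j: int):
    return where(arange(j)[:, None] <= arange(j)[None, :], 1, 0)

def cumsum(a):
    return (a[None, :] @ triu(a.shape[0]))[0]

g = torch.tensor([False, True, False, True])
v = torch.tensor([5, 6, 7, 8])

# cumsum(1 * g) - 1 is the output slot for each kept entry: [-1, 0, 0, 1].
slots = cumsum(1 * g) - 1
routing = where(g[:, None], arange(4) == slots[:, None], 0)
compressed = v @ routing
print(routing)
print(compressed)  # tensor([6, 8, 0, 0])
```

Row 2 also points at slot 0, but the mask zeroes it out, so only kept entries reach the output.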
Puzzle 13 - pad_to
Compute pad_to - eliminate or add 0s to change size of vector.
def pad_to(a: TT["i"], i: int, j: int) -> TT["j"]:
    return compress(ones(i) <= j, a * where(arange(i) < j, 1, 0), j)
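For j >= 1 the mask passed to compress is all True, so the routing matrix reduces to a rectangular identity of shape (i, j). The sketch below uses that matrix directly. It is an alternative formulation under that assumption, not the solution above, and the inputs are made up.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def pad_to(a, i: int, j: int):
    # (i,) @ (i, j): column k copies a[k] when k < i and stays 0 otherwise.
    return a @ where(arange(i)[:, None] == arange(j)[None, :], 1, 0)

a = torch.tensor([1, 2, 3])
padded = pad_to(a, 3, 5)
truncated = pad_to(a, 3, 2)
print(padded)     # tensor([1, 2, 3, 0, 0])
print(truncated)  # tensor([1, 2])
```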
Puzzle 14 - sequence_mask
Compute sequence_mask - pad out to length per batch.
def sequence_mask(values: TT["i", "j"], length: TT["i", int]) -> TT["i", "j"]:
    return (arange(values.shape[1])[None, :] < length[:, None]) * values
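A short sketch of the masking with made-up values: each row keeps its first length[r] entries and zeroes the rest.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def sequence_mask(values, length):
    # Column index (1, j) compared against per-row length (i, 1) -> (i, j) mask.
    return (arange(values.shape[1])[None, :] < length[:, None]) * values

values = torch.tensor([[1, 2, 3], [4, 5, 6]])
length = torch.tensor([2, 1])
masked = sequence_mask(values, length)
print(masked)
```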
Puzzle 15 - bincount
Compute bincount - count number of times an entry was seen.
def bincount(a: TT["i"], j: int) -> TT["j"]:
    return ones(a.shape[0]) @ where(a[:, None] == arange(j)[None, :], 1, 0)
Puzzle 16 - scatter_add
Compute scatter_add - add together values that link to the same location.
def scatter_add(values: TT["i"], link: TT["i"], j: int) -> TT["j"]:
    return values @ where(link[:, None] == arange(j)[None, :], 1, 0)
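Comparing link against arange(j) gives a one-hot matrix with a 1 at (r, link[r]). Weighting its rows by values and summing over rows is a single vector-matrix product, which is how the sketch below writes it. The inputs are made up, and the result is checked against PyTorch's built-in Tensor.scatter_add.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def scatter_add(values, link, j: int):
    # (i,) @ (i, j): output slot k collects every values[r] with link[r] == k.
    return values @ where(link[:, None] == arange(j)[None, :], 1, 0)

values = torch.tensor([1, 2, 3, 4])
link = torch.tensor([0, 2, 0, 1])
scattered = scatter_add(values, link, 3)
reference = torch.zeros(3, dtype=torch.int64).scatter_add(0, link, values)
print(scattered)  # tensor([4, 4, 2])
```

With values set to all ones, the same product counts occurrences, which is bincount.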
Puzzle 17 - flatten
Compute flatten - the matrix laid out row by row as a single vector.
def flatten(a: TT["i", "j"], i: int, j: int) -> TT["i * j"]:
    return a[arange(i * j) // j, arange(i * j) % j]
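Flattening can be read as an index lookup: output position p comes from row p // j and column p % j. A minimal sketch of that reading with a made-up matrix, checked against reshape:

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def flatten(a, i: int, j: int):
    # Pair each flat position p with its (row, column) = (p // j, p % j).
    return a[arange(i * j) // j, arange(i * j) % j]

a = torch.tensor([[1, 2, 3], [4, 5, 6]])
row_index = arange(6) // 3
col_index = arange(6) % 3
flat = flatten(a, 2, 3)
print(flat)  # tensor([1, 2, 3, 4, 5, 6])
```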
Puzzle 18 - linspace
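Compute linspace - n evenly spaced values from i to j, inclusive.
def linspace(i: TT[1], j: TT[1], n: int) -> TT["n", float]:
    # n is a plain int, so the divisor is guarded with arithmetic, not where.
    return i + (j - i) * arange(n) / (n - 1 + (n == 1))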
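The delicate case is n = 1, where the step divisor n - 1 would be zero. Because n is a plain Python int rather than a tensor, the sketch below guards the divisor with integer arithmetic: n - 1 + (n == 1) equals 1 when n is 1 and n - 1 otherwise. That guard is one possible choice, assumed here for illustration, and the endpoints are made up.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def linspace(i, j, n: int):
    # arange(n) / (n - 1) runs from 0 to 1; scale by (j - i) and shift by i.
    return i + (j - i) * arange(n) / (n - 1 + (n == 1))

lo = torch.tensor([0.0])
hi = torch.tensor([1.0])
five = linspace(lo, hi, 5)
single = linspace(lo, hi, 1)
print(five)    # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
print(single)  # tensor([0.])
```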
Puzzle 19 - heaviside
Compute heaviside - the step function of a, taking the value b where a is 0.
def heaviside(a: TT["i"], b: TT["i"]) -> TT["i"]:
    return where(a == 0, b, a > 0)
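A quick sketch with made-up float inputs shows the three cases (negative, zero, positive) and compares the result with PyTorch's torch.heaviside:

```python
import torch

def where(q, a, b):
    return (q * a) + (~q) * b

def heaviside(a, b):
    # Where a == 0 take b; elsewhere the comparison a > 0 yields 0 or 1.
    return where(a == 0, b, a > 0)

a = torch.tensor([-2.0, 0.0, 3.0])
b = torch.tensor([0.5, 0.5, 0.5])
step = heaviside(a, b)
print(step)  # tensor([0.0000, 0.5000, 1.0000])
```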
Puzzle 20 - repeat (1d)
Compute repeat - the vector a stacked d times as rows.
def repeat(a: TT["i"], d: TT[1]) -> TT["d", "i"]:
    return ones(d)[:, None] @ a[None, :]
Puzzle 21 - bucketize
Compute bucketize - for each value, the number of boundaries less than or equal to it.
def bucketize(v: TT["i"], boundaries: TT["j"]) -> TT["i"]:
    return ones(boundaries.shape[0]) @ where(v >= boundaries[:, None], 1, 0)
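The bucket index of a value is the count of boundaries at or below it. The comparison matrix has one row per boundary, and summing its rows gives that count. A minimal sketch with made-up inputs, written so the result is a 1-D vector. For sorted boundaries it agrees with torch.bucketize called with right=True.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def ones(i: int):
    return where(arange(i) > -1, 1, 0)

def bucketize(v, boundaries):
    # (j,) @ (j, i) -> (i,): for each value, count boundaries that are <= it.
    return ones(boundaries.shape[0]) @ where(v >= boundaries[:, None], 1, 0)

v = torch.tensor([1, 4, 9])
boundaries = torch.tensor([2, 4, 6, 8])
buckets = bucketize(v, boundaries)
print(buckets)  # tensor([0, 2, 4])
```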