136 lines
4.6 KiB
Python
136 lines
4.6 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# Copyright (c) 2026 by David Boetius
|
|||
|
|
# Licensed under the MIT License.
|
|||
|
|
"""Generate grad tests that expose known bugs in minijax's gradient logic.
|
|||
|
|
|
|||
|
|
Bug 1 – unbroadcast is incomplete (grad.py:73-75):
|
|||
|
|
unbroadcast only removes extra leading dimensions but does not reduce
|
|||
|
|
along size-1 broadcast dimensions. Tests: mul_broadcast, add_broadcast
|
|||
|
|
|
|||
|
|
Bug 2 – dot VJP is wrong for 1D inputs (grad.py:90):
|
|||
|
|
transpose(x) @ t computes an inner product instead of an outer product
|
|||
|
|
when both x and t are 1D vectors. Test: dot_1d
|
|||
|
|
|
|||
|
|
Expected gradients are computed analytically with numpy so they are
|
|||
|
|
independent of the buggy _backwards implementation.
|
|||
|
|
|
|||
|
|
Run from the repository root:
|
|||
|
|
python tests/milestone1/base/unit/grad/generate_bugfix_grad_tests.py
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
import numpy as np
|
|||
|
|
|
|||
|
|
from minijax import core
|
|||
|
|
from minijax.compute_graph import make_graph
|
|||
|
|
from minijax.eval import Array
|
|||
|
|
from minijax.serialize import dump
|
|||
|
|
|
|||
|
|
SCRIPT_DIR = Path(__file__).parent
|
|||
|
|
REPO_ROOT = SCRIPT_DIR.parents[4]
|
|||
|
|
|
|||
|
|
DEFAULT_TOLERANCE = 1e-4
|
|||
|
|
|
|||
|
|
|
|||
|
|
def create_test(name, fn, inputs_data, expected_grads, *, tolerance=DEFAULT_TOLERANCE):
|
|||
|
|
"""Write a grad test whose expected gradients are supplied externally."""
|
|||
|
|
test_dir = SCRIPT_DIR / name
|
|||
|
|
resources_dir = test_dir / "resources"
|
|||
|
|
resources_dir.mkdir(parents=True, exist_ok=True)
|
|||
|
|
|
|||
|
|
arrays = [Array(d) for d in inputs_data]
|
|||
|
|
graph = make_graph(fn)(*arrays)
|
|||
|
|
|
|||
|
|
network_file = f"{name}_network.mininn"
|
|||
|
|
dump(graph, resources_dir / network_file)
|
|||
|
|
|
|||
|
|
input_files = []
|
|||
|
|
for i, d in enumerate(inputs_data):
|
|||
|
|
fname = f"input_{i}.bin"
|
|||
|
|
np.asarray(d, dtype=np.float64).tofile(resources_dir / fname)
|
|||
|
|
input_files.append(f"resources/{fname}")
|
|||
|
|
|
|||
|
|
expected_files = []
|
|||
|
|
for i, grad in enumerate(expected_grads):
|
|||
|
|
fname = f"expected_grad_{i}.bin"
|
|||
|
|
np.asarray(grad, dtype=np.float64).tofile(test_dir / fname)
|
|||
|
|
expected_files.append(fname)
|
|||
|
|
|
|||
|
|
config = {
|
|||
|
|
"command": "grad",
|
|||
|
|
"network": f"resources/{network_file}",
|
|||
|
|
"inputs": input_files,
|
|||
|
|
"expected_outputs": expected_files,
|
|||
|
|
}
|
|||
|
|
if tolerance != DEFAULT_TOLERANCE:
|
|||
|
|
config["tolerance"] = tolerance
|
|||
|
|
(test_dir / "test.json").write_text(json.dumps(config, indent=4) + "\n")
|
|||
|
|
print(f" created {test_dir.relative_to(REPO_ROOT)}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
# Bug 1: unbroadcast incomplete — mul with broadcasting
|
|||
|
|
# f(x, y) = x * y, x shape (3, 4), y shape (1, 4)
|
|||
|
|
# d(sum f)/dx = broadcast(y, (3,4))
|
|||
|
|
# d(sum f)/dy = x.sum(axis=0, keepdims=True) ← requires reducing
|
|||
|
|
# along the size-1 broadcast axis, which unbroadcast misses
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
rng = np.random.default_rng(100)
|
|||
|
|
x_data = rng.standard_normal((3, 4))
|
|||
|
|
y_data = rng.standard_normal((1, 4))
|
|||
|
|
|
|||
|
|
create_test(
|
|||
|
|
"mul_broadcast",
|
|||
|
|
lambda x, y: core.mul(x, y),
|
|||
|
|
[x_data, y_data],
|
|||
|
|
expected_grads=[
|
|||
|
|
np.broadcast_to(y_data, (3, 4)), # dL/dx
|
|||
|
|
x_data.sum(axis=0, keepdims=True), # dL/dy
|
|||
|
|
],
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# Same bug, simpler case: add with broadcasting
|
|||
|
|
# f(x, y) = x + y, x shape (3, 4), y shape (1, 4)
|
|||
|
|
# d(sum f)/dx = ones(3, 4)
|
|||
|
|
# d(sum f)/dy = ones(1, 4) * 3 (sum of ones along broadcast axis)
|
|||
|
|
create_test(
|
|||
|
|
"add_broadcast",
|
|||
|
|
lambda x, y: core.add(x, y),
|
|||
|
|
[x_data, y_data],
|
|||
|
|
expected_grads=[
|
|||
|
|
np.ones((3, 4)), # dL/dx
|
|||
|
|
np.full((1, 4), 3.0), # dL/dy
|
|||
|
|
],
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
# Bug 2: dot VJP wrong for 1D inputs
|
|||
|
|
# f(x, y) = dot(x, y), x shape (8,), y shape (8, 5)
|
|||
|
|
# output shape (5,); out_tangent = ones(5,)
|
|||
|
|
# d(sum f)/dx_i = sum_j y_{ij} → y.sum(axis=1), shape (8,)
|
|||
|
|
# d(sum f)/dy_{ij} = x_i → x[:, None] * ones(1, 5), shape (8, 5)
|
|||
|
|
# The buggy VJP computes transpose(x) @ t = dot((8,), (5,)) which
|
|||
|
|
# raises an error (shape mismatch) or gives a scalar (if sizes equal).
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
x_dot = rng.standard_normal((8,))
|
|||
|
|
y_dot = rng.standard_normal((8, 5))
|
|||
|
|
|
|||
|
|
create_test(
|
|||
|
|
"dot_1d",
|
|||
|
|
lambda x, y: core.dot(x, y),
|
|||
|
|
[x_dot, y_dot],
|
|||
|
|
expected_grads=[
|
|||
|
|
y_dot.sum(axis=1), # dL/dx, shape (8,)
|
|||
|
|
np.broadcast_to(x_dot[:, None], (8, 5)).copy(), # dL/dy, shape (8, 5)
|
|||
|
|
],
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
print("Done.")
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
main()
|