Files
fvnn-minijax/tests/milestone1/base/unit/grad/generate_bugfix_grad_tests.py
T

136 lines
4.6 KiB
Python
Raw Normal View History

2026-05-26 10:24:33 +02:00
#!/usr/bin/env python3
# Copyright (c) 2026 by David Boetius
# Licensed under the MIT License.
"""Generate grad tests that expose known bugs in minijax's gradient logic.

Bug 1: unbroadcast is incomplete (grad.py:73-75).
    unbroadcast only removes extra leading dimensions but does not reduce
    along size-1 broadcast dimensions. Tests: mul_broadcast, add_broadcast.

Bug 2: dot VJP is wrong for 1D inputs (grad.py:90).
    transpose(x) @ t computes an inner product instead of an outer product
    when both x and t are 1D vectors. Test: dot_1d.

Expected gradients are computed analytically with numpy so they are
independent of the buggy _backwards implementation.

Run from the repository root:

    python tests/milestone1/base/unit/grad/generate_bugfix_grad_tests.py
"""
import json
from pathlib import Path
import numpy as np
from minijax import core
from minijax.compute_graph import make_graph
from minijax.eval import Array
from minijax.serialize import dump
# Directory containing this script; create_test writes each test case into a
# sub-directory of it.
SCRIPT_DIR = Path(__file__).parent
# Fifth ancestor of SCRIPT_DIR. With the script located under
# tests/milestone1/base/unit/grad/ this is the repository root; it is used
# below to print the generated paths relative to the repository.
REPO_ROOT = SCRIPT_DIR.parents[4]
# Default value of create_test's `tolerance` argument. A test.json only gets
# an explicit "tolerance" entry when the requested value differs from this.
# NOTE(review): presumably matches the test runner's built-in default - confirm.
DEFAULT_TOLERANCE = 1e-4
def create_test(name, fn, inputs_data, expected_grads, *, tolerance=DEFAULT_TOLERANCE):
    """Emit one ``grad`` test case directory called *name* next to this script.

    The reference gradients are not derived from minijax's own backward
    pass; the caller supplies them in *expected_grads* and they are stored
    as given.

    Files written below ``SCRIPT_DIR / name``:

    * ``resources/<name>_network.mininn`` - graph obtained by tracing *fn*
      on the inputs, serialized with ``dump``
    * ``resources/input_<i>.bin`` - each input as raw float64 values
    * ``expected_grad_<i>.bin`` - each expected gradient as raw float64 values
    * ``test.json`` - config referencing the files above

    Args:
        name: Directory name of the test case; also prefixes the network file.
        fn: Callable passed to ``make_graph`` and traced on the wrapped inputs.
        inputs_data: Input arrays, passed positionally to the traced *fn*.
        expected_grads: Reference gradients, written in the order given.
        tolerance: Recorded under ``"tolerance"`` in ``test.json`` only when
            it differs from ``DEFAULT_TOLERANCE``.
    """
    case_dir = SCRIPT_DIR / name
    assets_dir = case_dir / "resources"
    # Creating the nested directory also creates the test case directory.
    assets_dir.mkdir(parents=True, exist_ok=True)

    def save_float64(values, destination):
        # Raw binary dump (no header); values are coerced to float64 first.
        np.asarray(values, dtype=np.float64).tofile(destination)

    # Trace the function on wrapped inputs and serialize the resulting graph.
    wrapped_inputs = [Array(item) for item in inputs_data]
    traced = make_graph(fn)(*wrapped_inputs)
    network_name = f"{name}_network.mininn"
    dump(traced, assets_dir / network_name)

    # Inputs live in resources/ and are referenced relative to the case dir.
    input_refs = []
    for index, item in enumerate(inputs_data):
        bin_name = f"input_{index}.bin"
        save_float64(item, assets_dir / bin_name)
        input_refs.append(f"resources/{bin_name}")

    # Expected gradients sit directly in the case directory.
    expected_refs = []
    for index, reference in enumerate(expected_grads):
        bin_name = f"expected_grad_{index}.bin"
        save_float64(reference, case_dir / bin_name)
        expected_refs.append(bin_name)

    config = {
        "command": "grad",
        "network": f"resources/{network_name}",
        "inputs": input_refs,
        "expected_outputs": expected_refs,
    }
    # Keep test.json minimal: omit the tolerance when it is just the default.
    if tolerance != DEFAULT_TOLERANCE:
        config["tolerance"] = tolerance

    config_text = json.dumps(config, indent=4) + "\n"
    (case_dir / "test.json").write_text(config_text)
    print(f" created {case_dir.relative_to(REPO_ROOT)}")
def main():
    """Write the mul_broadcast, add_broadcast and dot_1d grad test cases.

    Every reference gradient is the gradient of the scalar ``sum(f(x, y))``
    and is worked out by hand with numpy. All random inputs come from a
    single generator seeded with 100 and are drawn in a fixed order.
    """
    rng = np.random.default_rng(100)

    # ------------------------------------------------------------------
    # Bug 1: unbroadcast incomplete.
    #   Operands: x of shape (3, 4), y of shape (1, 4); y is broadcast
    #   along axis 0, so its gradient must be summed back over that
    #   size-1 axis - the reduction unbroadcast misses.
    # ------------------------------------------------------------------
    x_data = rng.standard_normal((3, 4))
    y_data = rng.standard_normal((1, 4))

    # ------------------------------------------------------------------
    # Bug 2: dot VJP wrong for 1D inputs.
    #   Operands: x of shape (8,), y of shape (8, 5); output shape (5,),
    #   out_tangent = ones(5,).
    #   The buggy VJP computes transpose(x) @ t = dot((8,), (5,)), which
    #   raises an error (shape mismatch) or gives a scalar (if sizes equal).
    # ------------------------------------------------------------------
    x_dot = rng.standard_normal((8,))
    y_dot = rng.standard_normal((8, 5))

    # One row per test case: (name, traced function, inputs, expected grads).
    cases = (
        (
            # f(x, y) = x * y
            "mul_broadcast",
            lambda x, y: core.mul(x, y),
            [x_data, y_data],
            [
                # dL/dx = y broadcast to x's shape
                np.broadcast_to(y_data, (3, 4)),
                # dL/dy = x summed over the broadcast axis, shape (1, 4)
                np.sum(x_data, axis=0, keepdims=True),
            ],
        ),
        (
            # Same bug, simpler case: f(x, y) = x + y
            "add_broadcast",
            lambda x, y: core.add(x, y),
            [x_data, y_data],
            [
                # dL/dx = ones(3, 4)
                np.ones((3, 4)),
                # dL/dy = ones summed along the broadcast axis -> 3 everywhere
                np.full((1, 4), 3.0),
            ],
        ),
        (
            # f(x, y) = dot(x, y)
            "dot_1d",
            lambda x, y: core.dot(x, y),
            [x_dot, y_dot],
            [
                # dL/dx_i = sum_j y_{ij}, shape (8,)
                np.sum(y_dot, axis=1),
                # dL/dy_{ij} = x_i, shape (8, 5)
                np.broadcast_to(x_dot[:, np.newaxis], (8, 5)).copy(),
            ],
        ),
    )

    for case_name, traced_fn, case_inputs, case_grads in cases:
        create_test(case_name, traced_fn, case_inputs, expected_grads=case_grads)

    print("Done.")
# Generate the test cases only when executed as a script, so importing this
# module does not write any files.
if __name__ == "__main__":
    main()