102 lines
2.7 KiB
Python
102 lines
2.7 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) 2026 by David Boetius
|
|
# Licensed under the MIT License.
|
|
"""Generate eval tests for the same operations that expose grad bugs.
|
|
|
|
These eval tests verify that forward evaluation is correct for
|
|
broadcasting mul/add and 1D dot — the same operations whose VJP rules
|
|
are buggy.
|
|
|
|
Run from the repository root:
|
|
python tests/milestone1/base/unit/eval/generate_bugfix_eval_tests.py
|
|
"""
|
|
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
from minijax import core
|
|
from minijax.compute_graph import make_graph
|
|
from minijax.eval import Array
|
|
from minijax.serialize import dump
|
|
|
|
SCRIPT_DIR = Path(__file__).parent
|
|
REPO_ROOT = SCRIPT_DIR.parents[4]
|
|
|
|
DEFAULT_TOLERANCE = 1e-4
|
|
|
|
|
|
def create_test(name, fn, inputs_data, *, tolerance=DEFAULT_TOLERANCE):
|
|
test_dir = SCRIPT_DIR / name
|
|
resources_dir = test_dir / "resources"
|
|
resources_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
arrays = [Array(d) for d in inputs_data]
|
|
graph = make_graph(fn)(*arrays)
|
|
|
|
network_file = f"{name}_network.mininn"
|
|
dump(graph, resources_dir / network_file)
|
|
|
|
input_files = []
|
|
for i, d in enumerate(inputs_data):
|
|
fname = f"input_{i}.bin"
|
|
np.asarray(d, dtype=np.float64).tofile(resources_dir / fname)
|
|
input_files.append(f"resources/{fname}")
|
|
|
|
raw_out = fn(*arrays)
|
|
if not isinstance(raw_out, (list, tuple)):
|
|
raw_out = [raw_out]
|
|
|
|
expected_files = []
|
|
for i, out in enumerate(raw_out):
|
|
fname = f"expected_output_{i}.bin"
|
|
out.array.astype(np.float64).tofile(test_dir / fname)
|
|
expected_files.append(fname)
|
|
|
|
config = {
|
|
"command": "eval",
|
|
"network": f"resources/{network_file}",
|
|
"inputs": input_files,
|
|
"expected_outputs": expected_files,
|
|
}
|
|
if tolerance != DEFAULT_TOLERANCE:
|
|
config["tolerance"] = tolerance
|
|
(test_dir / "test.json").write_text(json.dumps(config, indent=4) + "\n")
|
|
print(f" created {test_dir.relative_to(REPO_ROOT)}")
|
|
|
|
|
|
def main():
|
|
rng = np.random.default_rng(100)
|
|
|
|
# mul with broadcasting: x (3, 4), y (1, 4)
|
|
x_data = rng.standard_normal((3, 4))
|
|
y_data = rng.standard_normal((1, 4))
|
|
create_test(
|
|
"mul_broadcast",
|
|
lambda x, y: core.mul(x, y),
|
|
[x_data, y_data],
|
|
)
|
|
|
|
# add with broadcasting: x (3, 4), y (1, 4)
|
|
create_test(
|
|
"add_broadcast",
|
|
lambda x, y: core.add(x, y),
|
|
[x_data, y_data],
|
|
)
|
|
|
|
# dot with 1D input: x (8,), y (8, 5)
|
|
x_dot = rng.standard_normal((8,))
|
|
y_dot = rng.standard_normal((8, 5))
|
|
create_test(
|
|
"dot_1d",
|
|
lambda x, y: core.dot(x, y),
|
|
[x_dot, y_dot],
|
|
)
|
|
|
|
print("Done.")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|