Preliminary state
This commit is contained in:
@@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2026 by David Boetius
|
||||
# Licensed under the MIT License.
|
||||
"""Generate eval tests for the same operations that expose grad bugs.
|
||||
|
||||
These eval tests verify that forward evaluation is correct for
|
||||
broadcasting mul/add and 1D dot — the same operations whose VJP rules
|
||||
are buggy.
|
||||
|
||||
Run from the repository root:
|
||||
python tests/milestone1/base/unit/eval/generate_bugfix_eval_tests.py
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from minijax import core
|
||||
from minijax.compute_graph import make_graph
|
||||
from minijax.eval import Array
|
||||
from minijax.serialize import dump
|
||||
|
||||
SCRIPT_DIR = Path(__file__).parent
|
||||
REPO_ROOT = SCRIPT_DIR.parents[4]
|
||||
|
||||
DEFAULT_TOLERANCE = 1e-4
|
||||
|
||||
|
||||
def create_test(name, fn, inputs_data, *, tolerance=DEFAULT_TOLERANCE):
|
||||
test_dir = SCRIPT_DIR / name
|
||||
resources_dir = test_dir / "resources"
|
||||
resources_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
arrays = [Array(d) for d in inputs_data]
|
||||
graph = make_graph(fn)(*arrays)
|
||||
|
||||
network_file = f"{name}_network.mininn"
|
||||
dump(graph, resources_dir / network_file)
|
||||
|
||||
input_files = []
|
||||
for i, d in enumerate(inputs_data):
|
||||
fname = f"input_{i}.bin"
|
||||
np.asarray(d, dtype=np.float64).tofile(resources_dir / fname)
|
||||
input_files.append(f"resources/{fname}")
|
||||
|
||||
raw_out = fn(*arrays)
|
||||
if not isinstance(raw_out, (list, tuple)):
|
||||
raw_out = [raw_out]
|
||||
|
||||
expected_files = []
|
||||
for i, out in enumerate(raw_out):
|
||||
fname = f"expected_output_{i}.bin"
|
||||
out.array.astype(np.float64).tofile(test_dir / fname)
|
||||
expected_files.append(fname)
|
||||
|
||||
config = {
|
||||
"command": "eval",
|
||||
"network": f"resources/{network_file}",
|
||||
"inputs": input_files,
|
||||
"expected_outputs": expected_files,
|
||||
}
|
||||
if tolerance != DEFAULT_TOLERANCE:
|
||||
config["tolerance"] = tolerance
|
||||
(test_dir / "test.json").write_text(json.dumps(config, indent=4) + "\n")
|
||||
print(f" created {test_dir.relative_to(REPO_ROOT)}")
|
||||
|
||||
|
||||
def main():
|
||||
rng = np.random.default_rng(100)
|
||||
|
||||
# mul with broadcasting: x (3, 4), y (1, 4)
|
||||
x_data = rng.standard_normal((3, 4))
|
||||
y_data = rng.standard_normal((1, 4))
|
||||
create_test(
|
||||
"mul_broadcast",
|
||||
lambda x, y: core.mul(x, y),
|
||||
[x_data, y_data],
|
||||
)
|
||||
|
||||
# add with broadcasting: x (3, 4), y (1, 4)
|
||||
create_test(
|
||||
"add_broadcast",
|
||||
lambda x, y: core.add(x, y),
|
||||
[x_data, y_data],
|
||||
)
|
||||
|
||||
# dot with 1D input: x (8,), y (8, 5)
|
||||
x_dot = rng.standard_normal((8,))
|
||||
y_dot = rng.standard_normal((8, 5))
|
||||
create_test(
|
||||
"dot_1d",
|
||||
lambda x, y: core.dot(x, y),
|
||||
[x_dot, y_dot],
|
||||
)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user