Preliminary state

This commit is contained in:
Anton Pogrebnjak
2026-05-26 10:24:33 +02:00
commit 0ab11cfc5e
950 changed files with 6428 additions and 0 deletions
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:658963e7fce1d5833471359fe7e6ad1211149d6af29f53ed5a7bf91c8065a37f
size 400
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:889b31ca546349820f688f2a10d0e81554f557e62f7cf9fa534fd3f573f5840c
size 175
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e4271cf02dceda71d3271d412d0610e2b523e437be779ba24f93ce41f863d0e3
size 400
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cdb519e8c06e9c7c5040ca545bd46afb969c6ac83216fab0ac8b0e0f045141d6
size 400
@@ -0,0 +1,11 @@
{
"command": "eval",
"network": "resources/add_network.mininn",
"inputs": [
"resources/input_0.bin",
"resources/input_1.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d6564733709d7dd0bc46c04261b5f646e09579740a71a689ff1930eef5bb0ef8
size 96
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:884273b6e5c7f60c6af52f730f7a094fa8a8ca41d932dc5f7403870f6c324d20
size 181
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:480286cdfa2dfbdcf7be15c1fd71c23397e6a52f55ebcee6fb1f9977ce9f44bc
size 96
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e11c34d76571d9e1dd6e2418f7048f030fdfce7dd41b91ba704694d83b21ac2f
size 32
@@ -0,0 +1,11 @@
{
"command": "eval",
"network": "resources/add_broadcast_network.mininn",
"inputs": [
"resources/input_0.bin",
"resources/input_1.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:74cdf61be260f12a55c2632a2f3203c999f528f87ffd06bc3726b73dde088483
size 400
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2d17711d7c566ef21a4dc1fa60865fc184b2dacf8cf0975a87081e2efab1c2f9
size 185
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:efe718831c28c8904df28ca06e2e36695c302cf826a86f6c6c20a024ec94a07f
size 640
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e068ccc39f44e8bf19764023f6636fe4bdf4716254f58e113e067963cc80f3cd
size 320
@@ -0,0 +1,11 @@
{
"command": "eval",
"network": "resources/dot_network.mininn",
"inputs": [
"resources/input_0.bin",
"resources/input_1.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4abb955095b89e04998f2c0d1c18571bd9430ea05afd90c585891fec657dfc40
size 40
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eb011a372c64fa0f0a71adc26a939fd75c31ab624c27a4381258d83023208719
size 173
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:829530ea5e098c191c28de7cbf525f84d8da0f24e2be9be46cca22a5ec9a6521
size 64
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c3635fb30b26b35e98f213b4460e88fc5d3e565b94e6c1eca0c921a923bdc744
size 320
@@ -0,0 +1,11 @@
{
"command": "eval",
"network": "resources/dot_1d_network.mininn",
"inputs": [
"resources/input_0.bin",
"resources/input_1.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f9dea9396128300cb2f4fc377d82176a7818bc990faa4e0ce995613e40029d2
size 400
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:76b5c151d06d8c5489170edc7c4c3b41778a893679370abfa9c2e257817b6ee3
size 162
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1c119bca6e9773d06de9e7132fb0638aa41eb4e1ae9dc706775e2e4be1013a1a
size 400
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/exp_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:03f36663d9df8ac128397cb636c613de63253cd8457146030f4061ae3da29027
size 800
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3470b2998b6c00e26bf61323824ad6741f8f97bf72082826a39f56b79f0302b2
size 194
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:03f36663d9df8ac128397cb636c613de63253cd8457146030f4061ae3da29027
size 800
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/expand_dims_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0b75c050039d34523d371215c5a0391c2d01457fd14cc44fb3a2cf306bcb46ef
size 8000
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e621b6023d483a1d167182ffcce67ec102006507f260ccdd20d1abb8024f391b
size 82510
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1b0efeefd5d5890b2445931ddba9fe5a52f944ec6f77bb4344c9ecd7a626733c
size 16000
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/fc_10x32_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c34fcbdae56d64bd7b2b9c8eeae9d42f8bd6a2bd89fa2126e28590f35ea2e91
size 6000
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2383b140624943fed01811d154b3c870b0ad55e106969fbff10c477645a3be3e
size 5135
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9b7f2224a5e4a02b05d5422ba4bcd34eab9b319bf1651ba4added2135860509c
size 32000
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/fc_10x4_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc4f2b551bbec51232eb00f3c9efc17da40361f192834f2a88ea706e2fca50a9
size 800
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc91ab8a8217f3b56132e5c95fde4160208b4ce6671d5d5c0d3d28ec76dbd5c7
size 1228
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7d10261596b5fcd9d5945a3f7129172ef9bf46b68244bc03611b51dfea3e4c22
size 800
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/fc_2x4_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1d0f53769d4796e024db0701164fe04b6e821bcf1e7f99a07a4b18480379143f
size 1600
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32cb3bc14ce820f73a7cc09bc1472dfbe64030fec5fcc9e2d668935514143037
size 36784
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:26b8939ad4478bbcf7b8abc0c59c5f60021b408347c207678239ba6076d68d92
size 1600
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/fc_2x64_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a5abc9c3916dc722a8d105f80fa40f476c20c5663a706b55e6130f6e3907391c
size 2400
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f71d8e195bcd6e69631ba4a6dc519d25e21980fbc0d9b3d2cf28e0b0c9293a91
size 6600
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19626383235794d85f1fc04833eda488bf280b19be13f00074accff36e724dc3
size 3200
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/fc_3x16_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cf0df70b3665ca4261ccb233a2f86bd0e39b0bf248c9a71d14b1bb87a736d067
size 6400
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1158f00bf18d0394f92d5d9fe4ec3fe75b4e51f1d4b0deb7cb188fb4c3772dd3
size 76499
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ad823984b2980ef06c4acc9d70f0d447e3c522657c16c67854fc432a6f49e395
size 19200
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/fc_3x64_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:78721334288595e0175cbf96e3af7f6e56c0723e7647674a0787c13f8c354af3
size 6000
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a324b699f2cc6ed1e60eded830c84bb37c7a890b839dc3b0ec25513271505628
size 38136
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2144b76cc35761b8878eec253698eec94a6266bf7454c111f1f0a6182ac4c489
size 4800
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/fc_5x32_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:37cf8339c1b94f614239f9bab0d64e66a6a748c8f864fea767430773ea6627df
size 1600
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e6c26ad0a431d7bc6f9865e16d6991131a744a2aee6aa121bf7ebd8b5b4bd1d0
size 2515
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:df17b853a8f193c24bdaaef0695314ea5b81c9706a3e392636aa56ea2b1b8418
size 1600
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/fc_5x4_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:90986832f0a9f2b7f0043cdc14680789cb293da672d60677546c65ac53fd0385
size 4800
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:360c3b254153829cbd41e84f2d8a7fae83eec783ae64f57742a6bf7be03c1701
size 16904
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5fd8cc59f733900cc8596dd2cbf46a3755afae8b42c42d79940a42d8b48af5b4
size 12800
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/fc_7x16_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,101 @@
#!/usr/bin/env python3
# Copyright (c) 2026 by David Boetius
# Licensed under the MIT License.
"""Generate eval tests for the same operations that expose grad bugs.
These eval tests verify that forward evaluation is correct for
broadcasting mul/add and 1D dot — the same operations whose VJP rules
are buggy.
Run from the repository root:
python tests/milestone1/base/unit/eval/generate_bugfix_eval_tests.py
"""
import json
from pathlib import Path
import numpy as np
from minijax import core
from minijax.compute_graph import make_graph
from minijax.eval import Array
from minijax.serialize import dump
SCRIPT_DIR = Path(__file__).parent
REPO_ROOT = SCRIPT_DIR.parents[4]
DEFAULT_TOLERANCE = 1e-4
def create_test(name, fn, inputs_data, *, tolerance=DEFAULT_TOLERANCE):
test_dir = SCRIPT_DIR / name
resources_dir = test_dir / "resources"
resources_dir.mkdir(parents=True, exist_ok=True)
arrays = [Array(d) for d in inputs_data]
graph = make_graph(fn)(*arrays)
network_file = f"{name}_network.mininn"
dump(graph, resources_dir / network_file)
input_files = []
for i, d in enumerate(inputs_data):
fname = f"input_{i}.bin"
np.asarray(d, dtype=np.float64).tofile(resources_dir / fname)
input_files.append(f"resources/{fname}")
raw_out = fn(*arrays)
if not isinstance(raw_out, (list, tuple)):
raw_out = [raw_out]
expected_files = []
for i, out in enumerate(raw_out):
fname = f"expected_output_{i}.bin"
out.array.astype(np.float64).tofile(test_dir / fname)
expected_files.append(fname)
config = {
"command": "eval",
"network": f"resources/{network_file}",
"inputs": input_files,
"expected_outputs": expected_files,
}
if tolerance != DEFAULT_TOLERANCE:
config["tolerance"] = tolerance
(test_dir / "test.json").write_text(json.dumps(config, indent=4) + "\n")
print(f" created {test_dir.relative_to(REPO_ROOT)}")
def main():
rng = np.random.default_rng(100)
# mul with broadcasting: x (3, 4), y (1, 4)
x_data = rng.standard_normal((3, 4))
y_data = rng.standard_normal((1, 4))
create_test(
"mul_broadcast",
lambda x, y: core.mul(x, y),
[x_data, y_data],
)
# add with broadcasting: x (3, 4), y (1, 4)
create_test(
"add_broadcast",
lambda x, y: core.add(x, y),
[x_data, y_data],
)
# dot with 1D input: x (8,), y (8, 5)
x_dot = rng.standard_normal((8,))
y_dot = rng.standard_normal((8, 5))
create_test(
"dot_1d",
lambda x, y: core.dot(x, y),
[x_dot, y_dot],
)
print("Done.")
if __name__ == "__main__":
main()
@@ -0,0 +1,267 @@
#!/usr/bin/env python3
# Copyright (c) 2026 by David Boetius
# Licensed under the MIT License.
"""Generate eval tests for trained neural network classifiers.
Networks are trained on Gaussian blob datasets using minijax.
Three families of networks are generated:
- Shallow: single hidden layer, widths 4128
- Deep FC: uniform-width MLPs with 210 layers, widths 464
- Residual: same as deep FC with skip connections on hidden layers
Run from the repository root:
python tests/milestone1/base/unit/eval/generate_networks.py
"""
import json
from pathlib import Path
import numpy as np
from minijax import core, nn
from minijax.compute_graph import make_graph
from minijax.eval import Array
from minijax.grad import _backwards, _forward
from minijax.nested_containers import flatten, map_structure, unflatten
from minijax.serialize import dump
SCRIPT_DIR = Path(__file__).parent
REPO_ROOT = SCRIPT_DIR.parents[4]
# ======================================================================
# Batch-compatible network definitions
# ======================================================================
def mlp_batched(x, params):
"""MLP on batched input x of shape (batch, features)."""
for p in params[:-1]:
x = x @ p["weight"] + p["bias"]
x = core.relu(x)
return x @ params[-1]["weight"] + params[-1]["bias"]
def residual_mlp_batched(x, params):
"""MLP with residual connections on batched input x of shape (batch, features).
Architecture:
params[0]: input projection (in_dim → width), followed by ReLU
params[1:-1]: residual blocks (width → width), each relu(linear(x)) + x
params[-1]: output projection (width → n_classes), no activation
"""
x = core.relu(x @ params[0]["weight"] + params[0]["bias"])
for p in params[1:-1]:
residual = x
x = core.relu(x @ p["weight"] + p["bias"])
x = x + residual
return x @ params[-1]["weight"] + params[-1]["bias"]
# ======================================================================
# Training with compiled gradients
# ======================================================================
def _make_compiled_value_and_grad(fn, example_primals):
"""Compile fn+gradient once; return a function that re-runs them cheaply.
example_primals must have the same structure and shapes as all future
calls. The graph is traced once and then executed with _forward/_backwards
on every subsequent invocation, avoiding repeated Python-level tracing.
"""
cg = make_graph(fn)(*example_primals)
_, in_structure = flatten(example_primals)
def v_and_g(*primals):
flat_primals, _ = flatten(primals)
primals_dict = _forward(cg, list(flat_primals))
loss_val = primals_dict[cg.outvars[0]]
in_tangents = _backwards(cg, primals_dict, [Array(1.0)])
grads = unflatten(in_structure, in_tangents)
return loss_val, grads
return v_and_g
def train(
X_arr, Y_arr, layer_sizes, n_classes,
*, n_epochs, lr, weight_decay_coef, rng_key, residual=False, init_scale=1.0,
):
"""Train a classifier with SGD on one-hot targets Y_arr.
Returns the trained params (list of {weight, bias} dicts).
init_scale multiplies the Kaiming-uniform weight initialisation; use a
value < 1 for very deep networks to prevent overflow in the softmax.
"""
params = nn.init_mlp(X_arr.shape[1], layer_sizes + [n_classes], rng_key)
if init_scale != 1.0:
params = [
{"weight": Array(p["weight"].array * init_scale), "bias": p["bias"]}
for p in params
]
forward = residual_mlp_batched if residual else mlp_batched
def loss_fn(params):
logits = forward(X_arr, params)
return nn.cross_entropy(logits, Y_arr) + Array(weight_decay_coef) * nn.weight_decay(params)
v_and_g = _make_compiled_value_and_grad(loss_fn, (params,))
for _ in range(n_epochs):
_, (grads,) = v_and_g(params)
params = map_structure(
lambda p, g: Array(p.array - lr * g.array),
params, grads,
)
return params
# ======================================================================
# Data generation
# ======================================================================
def make_blobs(n_features, n_classes, *, rng_seed):
"""Gaussian blob classification dataset (n_classes blobs, 200 pts each)."""
rng = np.random.default_rng(rng_seed)
n_per_class = 200
centers = rng.uniform(-3.0, 3.0, (n_classes, n_features))
X_parts, y_parts = [], []
for i, center in enumerate(centers):
X_parts.append(rng.normal(center, 0.7, (n_per_class, n_features)))
y_parts.append(np.full(n_per_class, i, dtype=np.int64))
X = np.vstack(X_parts).astype(np.float64)
y = np.concatenate(y_parts)
Y = np.eye(n_classes, dtype=np.float64)[y]
return X, Y
def make_linspace_inputs(X_train, n_points):
"""n_points test inputs covering the per-feature range of X_train.
Each row traverses a linear path from the minimum to the maximum of the
training data for each feature dimension independently.
"""
mins = X_train.min(axis=0)
maxs = X_train.max(axis=0)
t = np.linspace(0.0, 1.0, n_points)
return (mins + np.outer(t, maxs - mins)).astype(np.float64)
# ======================================================================
# Test creation (same pattern as generate.py)
# ======================================================================
DEFAULT_TOLERANCE = 1e-4
def create_test(name, fn, inputs_data, *, tolerance=DEFAULT_TOLERANCE):
test_dir = SCRIPT_DIR / name
resources_dir = test_dir / "resources"
resources_dir.mkdir(parents=True, exist_ok=True)
arrays = [Array(d) for d in inputs_data]
graph = make_graph(fn)(*arrays)
network_file = f"{name}_network.mininn"
dump(graph, resources_dir / network_file)
input_files = []
for i, d in enumerate(inputs_data):
fname = f"input_{i}.bin"
np.asarray(d, dtype=np.float64).tofile(resources_dir / fname)
input_files.append(f"resources/{fname}")
raw_out = fn(*arrays)
if not isinstance(raw_out, (list, tuple)):
raw_out = [raw_out]
expected_files = []
for i, out in enumerate(raw_out):
fname = f"expected_output_{i}.bin"
out.array.astype(np.float64).tofile(test_dir / fname)
expected_files.append(fname)
config = {
"command": "eval",
"network": f"resources/{network_file}",
"inputs": input_files,
"expected_outputs": expected_files,
}
if tolerance != DEFAULT_TOLERANCE:
config["tolerance"] = tolerance
(test_dir / "test.json").write_text(json.dumps(config, indent=4) + "\n")
print(f" created {test_dir.relative_to(REPO_ROOT)}")
def build_test(
name, layer_sizes, in_dim, n_classes,
*, n_epochs, lr, n_inputs,
weight_decay_coef=1e-4, residual=False, init_scale=1.0, rng_seed=42,
):
"""Train a network and generate a test directory for it."""
X, Y = make_blobs(in_dim, n_classes, rng_seed=rng_seed * 13 + in_dim)
# Standardise inputs so deep networks train without overflow.
X_mean = X.mean(axis=0)
X_std = X.std(axis=0) + 1e-8
X = (X - X_mean) / X_std
X_arr, Y_arr = Array(X), Array(Y)
params = train(
X_arr, Y_arr, layer_sizes, n_classes,
n_epochs=n_epochs, lr=lr, weight_decay_coef=weight_decay_coef,
rng_key=rng_seed, residual=residual, init_scale=init_scale,
)
forward = residual_mlp_batched if residual else mlp_batched
def network(x):
return forward(x, params)
x_test = make_linspace_inputs(X, n_inputs)
create_test(name, network, [x_test])
# ======================================================================
# Test configurations
# ======================================================================
def main():
print("Generating shallow network tests…")
# Single hidden layer, widths 4128; input dims 216
build_test("shallow_4", [4], in_dim=2, n_classes=2, n_epochs=500, lr=0.01, n_inputs=50, rng_seed=42)
build_test("shallow_8", [8], in_dim=2, n_classes=3, n_epochs=500, lr=0.01, n_inputs=100, rng_seed=43)
build_test("shallow_16", [16], in_dim=4, n_classes=3, n_epochs=500, lr=0.01, n_inputs=100, rng_seed=44)
build_test("shallow_32", [32], in_dim=8, n_classes=4, n_epochs=500, lr=0.01, n_inputs=150, rng_seed=45)
build_test("shallow_64", [64], in_dim=12, n_classes=5, n_epochs=500, lr=0.01, n_inputs=200, rng_seed=46)
build_test("shallow_128", [128], in_dim=16, n_classes=3, n_epochs=500, lr=0.01, n_inputs=250, rng_seed=47)
print("\nGenerating deep FC network tests…")
# Uniform-width MLPs, 210 layers, widths 464; input dims 216
build_test("fc_2x4", [4] * 2, in_dim=2, n_classes=2, n_epochs=600, lr=0.01, n_inputs=50, rng_seed=50)
build_test("fc_2x64", [64] * 2, in_dim=2, n_classes=2, n_epochs=600, lr=0.01, n_inputs=100, rng_seed=51)
build_test("fc_3x16", [16] * 3, in_dim=4, n_classes=3, n_epochs=600, lr=0.01, n_inputs=100, rng_seed=52)
build_test("fc_3x64", [64] * 3, in_dim=12, n_classes=4, n_epochs=600, lr=0.01, n_inputs=200, rng_seed=53)
build_test("fc_5x4", [4] * 5, in_dim=2, n_classes=2, n_epochs=800, lr=0.01, n_inputs=100, rng_seed=54)
build_test("fc_5x32", [32] * 5, in_dim=4, n_classes=5, n_epochs=800, lr=0.01, n_inputs=150, rng_seed=55)
build_test("fc_7x16", [16] * 7, in_dim=8, n_classes=3, n_epochs=800, lr=0.01, n_inputs=200, init_scale=0.1, rng_seed=56)
build_test("fc_10x4", [4] * 10, in_dim=16, n_classes=3, n_epochs=1000, lr=0.005, n_inputs=250, init_scale=0.1, rng_seed=57)
build_test("fc_10x32", [32] * 10, in_dim=8, n_classes=4, n_epochs=1000, lr=0.005, n_inputs=250, init_scale=0.1, rng_seed=58)
print("\nGenerating residual FC network tests…")
# Same architecture family, residual skip connections on hidden layers
build_test("residual_2x64", [64] * 2, in_dim=2, n_classes=2, n_epochs=600, lr=0.01, n_inputs=100, residual=True, rng_seed=60)
build_test("residual_3x16", [16] * 3, in_dim=4, n_classes=3, n_epochs=600, lr=0.01, n_inputs=100, residual=True, rng_seed=61)
build_test("residual_3x64", [64] * 3, in_dim=12, n_classes=4, n_epochs=600, lr=0.01, n_inputs=200, residual=True, rng_seed=62)
build_test("residual_5x4", [4] * 5, in_dim=2, n_classes=2, n_epochs=800, lr=0.01, n_inputs=100, residual=True, rng_seed=63)
build_test("residual_5x32", [32] * 5, in_dim=4, n_classes=5, n_epochs=800, lr=0.01, n_inputs=150, residual=True, rng_seed=64)
build_test("residual_7x16", [16] * 7, in_dim=8, n_classes=3, n_epochs=800, lr=0.01, n_inputs=200, residual=True, init_scale=0.1, rng_seed=65)
build_test("residual_10x4", [4] * 10, in_dim=16, n_classes=3, n_epochs=1000, lr=0.002, n_inputs=250, residual=True, init_scale=0.1, rng_seed=66)
build_test("residual_10x32", [32] * 10, in_dim=8, n_classes=4, n_epochs=1000, lr=0.002, n_inputs=250, residual=True, init_scale=0.1, rng_seed=67)
print("\nDone.")
if __name__ == "__main__":
main()
@@ -0,0 +1,151 @@
#!/usr/bin/env python3
# Copyright (c) 2026 by David Boetius
# Licensed under the MIT License.
"""Generate batched eval tests for all minijax primitives.
Run from the repository root:
python tests/milestone1/base/unit/eval/generate_primitive_eval_tests.py
"""
import json
from pathlib import Path
import numpy as np
from minijax import core # noqa: E402
from minijax.compute_graph import make_graph # noqa: E402
from minijax.eval import Array # noqa: E402
from minijax.serialize import dump # noqa: E402
SCRIPT_DIR = Path(__file__).parent
REPO_ROOT = SCRIPT_DIR.parents[4]
DEFAULT_TOLERANCE = 1e-4
def create_test(name, fn, inputs_data, *, tolerance=DEFAULT_TOLERANCE):
test_dir = SCRIPT_DIR / name
resources_dir = test_dir / "resources"
resources_dir.mkdir(parents=True, exist_ok=True)
arrays = [Array(d) for d in inputs_data]
graph = make_graph(fn)(*arrays)
network_file = f"{name}_network.mininn"
dump(graph, resources_dir / network_file)
input_files = []
for i, d in enumerate(inputs_data):
fname = f"input_{i}.bin"
np.asarray(d, dtype=np.float64).tofile(resources_dir / fname)
input_files.append(f"resources/{fname}")
raw_out = fn(*arrays)
if not isinstance(raw_out, (list, tuple)):
raw_out = [raw_out]
expected_files = []
for i, out in enumerate(raw_out):
fname = f"expected_output_{i}.bin"
out.array.astype(np.float64).tofile(test_dir / fname)
expected_files.append(fname)
config = {
"command": "eval",
"network": f"resources/{network_file}",
"inputs": input_files,
"expected_outputs": expected_files,
}
if tolerance != DEFAULT_TOLERANCE:
config["tolerance"] = tolerance
(test_dir / "test.json").write_text(json.dumps(config, indent=4) + "\n")
print(f" created {test_dir.relative_to(REPO_ROOT)}")
def main():
# neg: 100 values, shape (100,)
create_test("neg", lambda x: core.neg(x), [np.linspace(-5.0, 5.0, 100)])
# add: two inputs, shape (50,) each
create_test(
"add", lambda x, y: core.add(x, y), [np.linspace(-5.0, 5.0, 50), np.linspace(0.0, 10.0, 50)]
)
# mul: two inputs, shape (50,) each
create_test(
"mul", lambda x, y: core.mul(x, y), [np.linspace(-5.0, 5.0, 50), np.linspace(0.0, 10.0, 50)]
)
# dot: batched matrix multiply — 10 samples of 8 features → 5 outputs
# x: (10, 8) = 80 values; y: (8, 5) = 40 values → output (10, 5)
create_test(
"dot",
lambda x, y: core.dot(x, y),
[np.linspace(-1.0, 1.0, 80).reshape(10, 8), np.linspace(-1.0, 1.0, 40).reshape(8, 5)],
)
# reciprocal: 50 strictly-positive values, shape (50,)
create_test("reciprocal", lambda x: core.reciprocal(x), [np.linspace(0.5, 5.0, 50)])
# relu: 100 values spanning negative and positive, shape (100,)
create_test("relu", lambda x: core.relu(x), [np.linspace(-5.0, 5.0, 100)])
# square: 100 values, shape (10, 10)
create_test("square", lambda x: core.square(x), [np.linspace(-5.0, 5.0, 100).reshape(10, 10)])
# sqrt: 100 strictly-positive values, shape (10, 10)
create_test("sqrt", lambda x: core.sqrt(x), [np.linspace(0.01, 5.0, 100).reshape(10, 10)])
# exp: 50 values, shape (50,) — moderate range to avoid overflow
create_test("exp", lambda x: core.exp(x), [np.linspace(-3.0, 3.0, 50)])
# log: 100 strictly-positive values, shape (10, 10)
create_test("log", lambda x: core.log(x), [np.linspace(0.01, 10.0, 100).reshape(10, 10)])
# where: 3 inputs each shape (50,)
# condition: 1.0 for even indices, 0.0 for odd (truthy/falsy)
create_test(
"where",
lambda c, x, y: core.where(c, x, y),
[
np.where(np.arange(50) % 2 == 0, 1.0, 0.0),
np.linspace(-5.0, 5.0, 50),
np.linspace(5.0, -5.0, 50),
],
)
# expand_dims: shape (10, 10) → (10, 10, 1), new axis at -1
create_test(
"expand_dims",
lambda x: core.expand_dims(x, axes=-1),
[np.linspace(-5.0, 5.0, 100).reshape(10, 10)],
)
# moveaxis: shape (10, 5, 2) → (5, 2, 10), move axis 0 to -1
create_test(
"moveaxis",
lambda x: core.moveaxis(x, source=0, destination=-1),
[np.linspace(-5.0, 5.0, 100).reshape(10, 5, 2)],
)
# reshape: shape (10, 10) → (20, 5)
create_test(
"reshape",
lambda x: core.reshape(x, new_shape=(20, 5)),
[np.linspace(-5.0, 5.0, 100).reshape(10, 10)],
)
# reduce_sum: shape (10, 10), sum over axis 1 → (10,)
create_test(
"reduce_sum",
lambda x: core.reduce_sum(x, axes=1),
[np.linspace(-5.0, 5.0, 100).reshape(10, 10)],
)
print("Done.")
if __name__ == "__main__":
main()
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8f6397b053a731e9a07c71601d5475355158c8100fcd438c2acaf86bc18f792f
size 800
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:255e3c9b0b8ae94ecb0e93df205acde17dbcf2570f8e9a3a87b63e9c1ac5e01e
size 800
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:56bc17a18153746fbc88b64ad87745490cb21d9854379261511bcfbb0eb6f0fb
size 174
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/log_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cb1a88d9f67def4ee21df2bf46f82a805e57b37c19c1588df92f5c825fb62e04
size 800
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:03f36663d9df8ac128397cb636c613de63253cd8457146030f4061ae3da29027
size 800
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:60975bf29f7d0498389101fb7704f70ee66e64d12728909d14fd9956ed15933c
size 209
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/moveaxis_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3e302f018394e83ebf9ba5a6801a4f10dda09f4d5b26e9f03424eff03ab4cc6e
size 400
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e4271cf02dceda71d3271d412d0610e2b523e437be779ba24f93ce41f863d0e3
size 400
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cdb519e8c06e9c7c5040ca545bd46afb969c6ac83216fab0ac8b0e0f045141d6
size 400
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e1e1570c31284cacb17700213c763d4f46c324219aed2fdf63aeea54d8dd46c3
size 175
@@ -0,0 +1,11 @@
{
"command": "eval",
"network": "resources/mul_network.mininn",
"inputs": [
"resources/input_0.bin",
"resources/input_1.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bee0ce33e2ac6078db77753e8ff4495e09771ea8059027b74d34cb933a1efca9
size 96
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:480286cdfa2dfbdcf7be15c1fd71c23397e6a52f55ebcee6fb1f9977ce9f44bc
size 96
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e11c34d76571d9e1dd6e2418f7048f030fdfce7dd41b91ba704694d83b21ac2f
size 32
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:504d2197eb20290ecb140489e55bcc60154c1bbf2762a2e61079bd344c3da217
size 181
@@ -0,0 +1,11 @@
{
"command": "eval",
"network": "resources/mul_broadcast_network.mininn",
"inputs": [
"resources/input_0.bin",
"resources/input_1.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4bf4d441a64a58347aa59ffe0488975e3d55b1bde05ac6f1280040bd6c25247
size 800
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:03f36663d9df8ac128397cb636c613de63253cd8457146030f4061ae3da29027
size 800
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:221a991df389e22fd804e62cc66b4affc2dd5c7d7d8022826ae476cdc72839b2
size 166
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/neg_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c4d81f2d318a9429f4a308dbcc6a5e54e7b73ee79ae35449b5d411e910287cf0
size 400
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b18e1f9066c263958700abbf32e49538a91ad352b45a236e5246f4a7b1b22674
size 400
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc9f4f4be0a08d32cf5a1a582bdac2cb145b704dfd2d460279622cca0e2649e8
size 169
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/reciprocal_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e6d0e72c80753bbdf7d107c634477c629f03794c3ca2f44d42c5cb73cd0bb33c
size 80
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:03f36663d9df8ac128397cb636c613de63253cd8457146030f4061ae3da29027
size 800
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:381017a0da33b0cf69a9ae17e19af366a823ec534f78b4cf467735bf1c3570b0
size 185
@@ -0,0 +1,10 @@
{
"command": "eval",
"network": "resources/reduce_sum_network.mininn",
"inputs": [
"resources/input_0.bin"
],
"expected_outputs": [
"expected_output_0.bin"
]
}
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3f5443ff02a63605239d3d53c8117e8d00ad23a48b4578bb10031e31d423fdcb
size 800
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:03f36663d9df8ac128397cb636c613de63253cd8457146030f4061ae3da29027
size 800
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fbf0ce31ed14ca32d6eabe65709512ec4108e92247a76f47f907ee0584b8a46d
size 167

Some files were not shown because too many files have changed in this diff Show More