diff --git a/.github/workflows/neuralnet-ci.yml b/.github/workflows/neuralnet-ci.yml
new file mode 100644
index 000000000..72ccee02c
--- /dev/null
+++ b/.github/workflows/neuralnet-ci.yml
@@ -0,0 +1,43 @@
+name: neuralnet-ci
+
+on:
+  pull_request:
+    branches: [ main ]
+  push:
+    branches: [ main ]
+
+jobs:
+  cleanup-before:
+    uses: ./.github/workflows/_cleanup.yml
+    with:
+      when: "before"
+
+  test-neuralnet:
+    needs: cleanup-before
+    name: test-neuralnet - ubuntu-latest
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.11"
+
+      - name: Upgrade pip
+        run: python -m pip install -U pip
+
+      - name: Install pufferlib
+        run: |
+          pip install -e .[cpu] --no-cache-dir
+        env:
+          TMPDIR: ${{ runner.temp }}/build
+          PIP_NO_CACHE_DIR: 1
+
+      - name: Compile C extensions
+        run: python setup.py build_ext --inplace --force
+
+      - name: Run Forward pass
+        run: python tests/test_drivenet.py
+        timeout-minutes: 15
diff --git a/pufferlib/ocean/drive/binding.c b/pufferlib/ocean/drive/binding.c
index 1061f33b7..7b6e0f573 100644
--- a/pufferlib/ocean/drive/binding.c
+++ b/pufferlib/ocean/drive/binding.c
@@ -1,7 +1,14 @@
 #include "drive.h"
+#include "drivenet.h"
+#include <Python.h>
+
 #define Env Drive
 #define MY_SHARED
 #define MY_PUT
+
+static PyObject* test_forward(PyObject* self, PyObject* args, PyObject* kwargs);
+#define MY_METHODS {"test_forward", (PyCFunction)test_forward, METH_VARARGS | METH_KEYWORDS, "Test forward pass"}
+
 #include "../env_binding.h"
 
 static int my_put(Env* env, PyObject* args, PyObject* kwargs) {
@@ -220,3 +227,120 @@ static int my_log(PyObject* dict, Log* log) {
     assign_to_dict(dict, "avg_collisions_per_agent", log->avg_collisions_per_agent);
     return 0;
 }
+
+// Run one DriveNet forward pass from Python and return its actions and logits.
+//
+// Keyword arguments:
+//   observations   - array-like; converted to a C-contiguous float32 2D array
+//                    whose first dimension is the batch size
+//   weights_file   - path passed to load_weights()
+//   dynamics_model - integer passed to init_drivenet()
+//
+// Returns a tuple (actions, logits): an int32 array of shape (batch, 2) and a
+// float32 array of shape (batch, 20). On failure returns NULL with a Python
+// exception set.
+static PyObject* test_forward(PyObject* self, PyObject* args, PyObject* kwargs) {
+    // 20 = 7 + 13 (steering + speed logits)
+    enum { NUM_ACTIONS = 2, NUM_LOGITS = 20 };
+
+    PyArrayObject* obs_array = NULL;
+    PyObject* actions_array = NULL;
+    PyObject* logits_array = NULL;
+    PyObject* result = NULL;
+    Weights* weights = NULL;
+    DriveNet* net = NULL;
+    npy_intp action_dims[2];
+    npy_intp logit_dims[2];
+    int batch_size;
+
+    // With METH_KEYWORDS, kwargs is NULL when the caller passes no keywords.
+    if (kwargs == NULL) {
+        PyErr_SetString(PyExc_TypeError,
+            "test_forward requires keyword arguments: observations, weights_file, dynamics_model");
+        return NULL;
+    }
+
+    // Borrowed reference; PyDict_GetItemString returns NULL without setting an error.
+    PyObject* obs_obj = PyDict_GetItemString(kwargs, "observations");
+    if (obs_obj == NULL) {
+        PyErr_SetString(PyExc_TypeError, "Missing required keyword argument 'observations'");
+        return NULL;
+    }
+
+    const char* weights_file = unpack_str(kwargs, "weights_file");
+    const int dynamics_model = unpack(kwargs, "dynamics_model");
+    // NOTE(review): assumes the unpack helpers report failure by setting a Python
+    // exception -- confirm against env_binding.h.
+    if (PyErr_Occurred()) {
+        return NULL;
+    }
+    if (weights_file == NULL) {
+        PyErr_SetString(PyExc_TypeError, "weights_file must be a string");
+        return NULL;
+    }
+
+    // New reference. Guarantees float32, aligned, C-contiguous data so the raw
+    // pointer handed to forward() is valid for any array-like input.
+    obs_array = (PyArrayObject*)PyArray_FROM_OTF(obs_obj, NPY_FLOAT32, NPY_ARRAY_IN_ARRAY);
+    if (obs_array == NULL) {
+        return NULL;
+    }
+    if (PyArray_NDIM(obs_array) != 2 || PyArray_DIM(obs_array, 0) < 1) {
+        PyErr_SetString(PyExc_ValueError, "observations must be a non-empty 2D array");
+        goto cleanup;
+    }
+    batch_size = (int)PyArray_DIM(obs_array, 0);
+
+    weights = load_weights(weights_file);
+    if (!weights) {
+        PyErr_SetString(PyExc_RuntimeError, "Failed to load weights");
+        goto cleanup;
+    }
+
+    net = init_drivenet(weights, batch_size, dynamics_model);
+    if (!net) {
+        PyErr_SetString(PyExc_RuntimeError, "Failed to initialize DriveNet");
+        goto cleanup;
+    }
+
+    action_dims[0] = batch_size;
+    action_dims[1] = NUM_ACTIONS;
+    actions_array = PyArray_SimpleNew(2, action_dims, NPY_INT32);
+    if (actions_array == NULL) {
+        goto cleanup;
+    }
+
+    logit_dims[0] = batch_size;
+    logit_dims[1] = NUM_LOGITS;
+    logits_array = PyArray_SimpleNew(2, logit_dims, NPY_FLOAT32);
+    if (logits_array == NULL) {
+        goto cleanup;
+    }
+
+    forward(net, (float*)PyArray_DATA(obs_array), (int*)PyArray_DATA((PyArrayObject*)actions_array));
+    // NOTE(review): assumes net->actor->output holds at least
+    // batch_size * NUM_LOGITS floats -- confirm against drivenet.h.
+    memcpy(PyArray_DATA((PyArrayObject*)logits_array), net->actor->output,
+           (size_t)batch_size * NUM_LOGITS * sizeof(float));
+
+    result = PyTuple_New(2);
+    if (result == NULL) {
+        goto cleanup;
+    }
+    // PyTuple_SetItem steals both references; clear the locals so cleanup
+    // does not release them a second time.
+    PyTuple_SetItem(result, 0, actions_array);
+    PyTuple_SetItem(result, 1, logits_array);
+    actions_array = NULL;
+    logits_array = NULL;
+
+cleanup:
+    if (net) {
+        free_drivenet(net);
+    }
+    free(weights);
+    Py_XDECREF(actions_array);
+    Py_XDECREF(logits_array);
+    Py_DECREF(obs_array);
+    return result;
+}
diff --git a/pufferlib/resources/drive/puffer_drive_weights.bin b/pufferlib/resources/drive/puffer_drive_weights.bin
index 29737db60..4cad564ef 100644
Binary files a/pufferlib/resources/drive/puffer_drive_weights.bin and b/pufferlib/resources/drive/puffer_drive_weights.bin differ
diff --git a/pufferlib/resources/drive/puffer_drive_weights.pt b/pufferlib/resources/drive/puffer_drive_weights.pt
new file mode 100644
index 000000000..895bd8e44
Binary files /dev/null and b/pufferlib/resources/drive/puffer_drive_weights.pt differ
diff --git a/setup.py b/setup.py
index e3d86c0e6..86c08f1f1 100644
--- a/setup.py
+++ b/setup.py
@@ -253,6 +253,7 @@ def run(self):
         for c_ext in c_extensions:
             if "drive" in c_ext.name:
                 c_ext.sources.append("inih-r62/ini.c")
+                c_ext.include_dirs.append("pufferlib/extensions")
c_ext.extra_compile_args.extend( [ '-DINI_START_COMMENT_PREFIXES="#"', diff --git a/tests/test_drivenet.py b/tests/test_drivenet.py new file mode 100644 index 000000000..a4f5ce3a0 --- /dev/null +++ b/tests/test_drivenet.py @@ -0,0 +1,70 @@ +import os +import sys +import numpy as np +import torch + +from pufferlib.ocean.torch import Drive, Recurrent +from pufferlib.ocean import env_creator +from pufferlib.ocean.drive import binding + + +def test_drivenet( + pt_file="resources/drive/puffer_drive_weights.pt", + bin_file="resources/drive/puffer_drive_weights.bin", + batch_size=4, + seed=42, +): + """Compare logits from PyTorch and C implementations.""" + + assert os.path.exists(bin_file), f"{bin_file} not found" + assert os.path.exists(pt_file), f"{pt_file} not found" + + env = env_creator("puffer_drive")(num_maps=1, num_agents=batch_size, scenario_length=91) + policy = Drive(env, input_size=64, hidden_size=256) + model = Recurrent(env, policy=policy, input_size=256, hidden_size=256) + + state_dict = torch.load(pt_file, map_location="cpu") + model.load_state_dict(state_dict) + model.eval() + + np.random.seed(seed) + torch.manual_seed(seed) + obs = np.random.randn(batch_size, env.num_obs).astype(np.float32) + + # Categorical road type features must be integers 0-6 + road_start = 7 + 63 * 7 + for i in range(200): + obs[:, road_start + i * 7 + 6] = np.random.randint(0, 7, size=batch_size) + + with torch.no_grad(): + lstm_state = { + "lstm_h": torch.zeros(1, batch_size, 256), + "lstm_c": torch.zeros(1, batch_size, 256), + } + actions_torch, _ = model.forward(torch.from_numpy(obs), lstm_state) + + logits_torch = torch.cat(actions_torch, dim=1).cpu().numpy() + + # C forward pass + _, logits_c = binding.test_forward(observations=obs, weights_file=bin_file, dynamics_model=0) + + diff = np.abs(logits_torch - logits_c) + max_diff = diff.max() + mean_diff = diff.mean() + + print(f"First batch:") + print(f"Logits PyTorch: {logits_torch[0, :10]}") + print(f"Logits C: {logits_c[0, 
:10]}") + print(f" Diff: {diff[0, :10]}") + print(f" Max difference: {max_diff:.6f}") + print(f" Mean difference: {mean_diff:.6f}") + + if max_diff < 1e-2: + return True + else: + return False + + +if __name__ == "__main__": + success = test_drivenet() + sys.exit(0 if success else 1)