diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 56da0e5..85c3148 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -4,10 +4,10 @@ on: [push, pull_request] jobs: test: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 strategy: matrix: - python-version: ['3.8', '3.9', '3.10'] + python-version: ['3.10', '3.12'] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/cla-check.yaml b/.github/workflows/cla-check.yaml index 95f92ab..cded220 100644 --- a/.github/workflows/cla-check.yaml +++ b/.github/workflows/cla-check.yaml @@ -4,9 +4,7 @@ on: [pull_request] jobs: cla-check: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - name: Check if CLA signed - uses: canonical/has-signed-canonical-cla@v1 - with: - accept-existing-contributors: true + uses: canonical/has-signed-canonical-cla@v2 diff --git a/debian/changelog b/debian/changelog index bd937c9..07853bf 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,11 @@ +probert (0.0.20ubuntu0~24.04) UNRELEASED; urgency=medium + + * Stop leaning on modules implemented in C, backport the pyroute2 + implementation from core24 instead (LP: #2139131). + * Drop tests and tests data from .deb packages + + -- Olivier Gayot Fri, 23 Jan 2026 11:51:59 +0100 + probert (0.0.20) groovy; urgency=medium [ Ryan Harper ] diff --git a/debian/control b/debian/control index 43b2ce4..3bf4b6b 100644 --- a/debian/control +++ b/debian/control @@ -4,14 +4,12 @@ Priority: optional Maintainer: Ubuntu Developers Build-Depends: debhelper-compat (= 13), dh-python, - libnl-genl-3-dev, - libnl-route-3-dev, - pkg-config, - python3-all-dev, + python3-all, python3-coverage, python3-flake8, python3-jsonschema, python3-nose, + python3-pyroute2 , python3-pyudev, python3-setuptools, Standards-Version: 4.5.0 @@ -31,7 +29,7 @@ Description: Hardware probing tool - metapackage Package: probert-common Architecture: all -Depends: ${misc:Depends}, ${python3:Depends}, ${shlibs:Depends} +Depends: ${misc:Depends}, ${python3:Depends} Breaks: probert (<< 0.0.16) Replaces: probert (<< 0.0.16) Description: Hardware probing tool - common @@ -51,7 +49,6 @@ Depends: bcache-tools, zfsutils-linux, ${misc:Depends}, ${python3:Depends}, - ${shlibs:Depends} Breaks: probert (<< 0.0.16) Replaces: probert (<< 0.0.16) Description: Hardware probing tool - storage probing @@ -61,11 +58,10 @@ Description: Hardware probing tool - storage probing This package contains storage probing capability. Package: probert-network -Architecture: any +Architecture: all Depends: probert-common (= ${source:Version}), ${misc:Depends}, ${python3:Depends}, - ${shlibs:Depends} Breaks: probert (<< 0.0.16) Replaces: probert (<< 0.0.16) Description: Hardware probing tool - network probing diff --git a/debian/probert-common.install b/debian/probert-common.install index b4706c6..1edd98f 100644 --- a/debian/probert-common.install +++ b/debian/probert-common.install @@ -6,16 +6,4 @@ usr/lib/python3.*/dist-packages/probert-*.egg-info/top_level.txt usr/lib/python3.*/dist-packages/probert/__init__.py usr/lib/python3.*/dist-packages/probert/log.py usr/lib/python3.*/dist-packages/probert/prober.py -usr/lib/python3.*/dist-packages/probert/tests/__init__.py -usr/lib/python3.*/dist-packages/probert/tests/data/dasdd.view -usr/lib/python3.*/dist-packages/probert/tests/data/dasde.view -usr/lib/python3.*/dist-packages/probert/tests/data/fake_probe_all.json -usr/lib/python3.*/dist-packages/probert/tests/fakes.py -usr/lib/python3.*/dist-packages/probert/tests/helpers.py -usr/lib/python3.*/dist-packages/probert/tests/test_dasd.py -usr/lib/python3.*/dist-packages/probert/tests/test_lvm.py -usr/lib/python3.*/dist-packages/probert/tests/test_multipath.py -usr/lib/python3.*/dist-packages/probert/tests/test_prober.py -usr/lib/python3.*/dist-packages/probert/tests/test_storage.py -usr/lib/python3.*/dist-packages/probert/tests/test_utils.py usr/lib/python3.*/dist-packages/probert/utils.py diff --git a/debian/probert-network.install b/debian/probert-network.install index c0bbfdb..770c14c 100644 --- a/debian/probert-network.install +++ b/debian/probert-network.install @@ -1,5 +1,8 @@ -usr/lib/python3.*/dist-packages/probert/_nl80211.*.so -usr/lib/python3.*/dist-packages/probert/_nl80211module.c -usr/lib/python3.*/dist-packages/probert/_rtnetlink.*.so -usr/lib/python3.*/dist-packages/probert/_rtnetlinkmodule.c usr/lib/python3.*/dist-packages/probert/network.py +usr/lib/python3.*/dist-packages/probert/nl80211.py +usr/lib/python3.*/dist-packages/probert/rtnetlink/addr.py +usr/lib/python3.*/dist-packages/probert/rtnetlink/cache.py +usr/lib/python3.*/dist-packages/probert/rtnetlink/route.py +usr/lib/python3.*/dist-packages/probert/rtnetlink/link.py +usr/lib/python3.*/dist-packages/probert/rtnetlink/listener.py +usr/lib/python3.*/dist-packages/probert/rtnetlink/__init__.py diff --git a/debian/probert-storage.install b/debian/probert-storage.install index edf357c..2e6ad76 100644 --- a/debian/probert-storage.install +++ b/debian/probert-storage.install @@ -5,6 +5,8 @@ usr/lib/python3.*/dist-packages/probert/filesystem.py usr/lib/python3.*/dist-packages/probert/lvm.py usr/lib/python3.*/dist-packages/probert/mount.py usr/lib/python3.*/dist-packages/probert/multipath.py +usr/lib/python3.*/dist-packages/probert/nvme.py +usr/lib/python3.*/dist-packages/probert/os.py usr/lib/python3.*/dist-packages/probert/raid.py usr/lib/python3.*/dist-packages/probert/storage.py usr/lib/python3.*/dist-packages/probert/zfs.py diff --git a/probert/_nl80211module.c b/probert/_nl80211module.c deleted file mode 100644 index 0bba90d..0000000 --- a/probert/_nl80211module.c +++ /dev/null @@ -1,758 +0,0 @@ -#define PY_SSIZE_T_CLEAN -#include -#include -#include - -#include -#include - -#include -#include - -#define NL_CB_me NL_CB_DEFAULT - -struct Listener { - PyObject_HEAD - PyObject *observer; - struct nl_sock* event_sock; - struct nl_sock* genl_sock; - PyObject *exc_typ, *exc_val, *exc_tb; - int err; - int nl80211_id; -}; - - -static void -listener_dealloc(PyObject *self) { - struct Listener* v = (struct Listener*)self; - PyObject_GC_UnTrack(v); - Py_CLEAR(v->observer); - Py_CLEAR(v->exc_typ); - Py_CLEAR(v->exc_val); - Py_CLEAR(v->exc_tb); - nl_socket_free(v->event_sock); - PyObject_GC_Del(v); -} - -static int -listener_traverse(PyObject *self, visitproc visit, void *arg) -{ - struct Listener* v = (struct Listener*)self; - Py_VISIT(v->observer); - Py_VISIT(v->exc_typ); - Py_VISIT(v->exc_val); - Py_VISIT(v->exc_tb); - return 0; -} - -static PyTypeObject ListenerType; - -static int ack_handler(struct nl_msg *msg, void *arg) { - int *err = arg; - *err = 0; - return NL_STOP; -} - -static int finish_handler(struct nl_msg *msg, void *arg) { - int *ret = arg; - *ret = 0; - return NL_SKIP; -} - -static int error_handler(struct sockaddr_nl *nla, struct nlmsgerr *err, - void *arg) { - int *ret = arg; - *ret = err->error; - return NL_STOP; -} - -static int no_seq_check(struct nl_msg *msg, void *arg) { return NL_OK; } - -struct nl80211_multicast_ids { - int mlme_id; - int scan_id; -}; - -static int family_handler(struct nl_msg *msg, void *arg) -{ - struct nl80211_multicast_ids *res = arg; - struct nlattr *tb[CTRL_ATTR_MAX + 1]; - struct genlmsghdr *gnlh = nlmsg_data(nlmsg_hdr(msg)); - struct nlattr *mcgrp; - int i; - - nla_parse(tb, CTRL_ATTR_MAX, genlmsg_attrdata(gnlh, 0), - genlmsg_attrlen(gnlh, 0), NULL); - if (!tb[CTRL_ATTR_MCAST_GROUPS]) - return NL_SKIP; - - nla_for_each_nested(mcgrp, tb[CTRL_ATTR_MCAST_GROUPS], i) { - struct nlattr *tb2[CTRL_ATTR_MCAST_GRP_MAX + 1]; - char *name; - int len; - nla_parse(tb2, CTRL_ATTR_MCAST_GRP_MAX, nla_data(mcgrp), nla_len(mcgrp), - NULL); - if (!tb2[CTRL_ATTR_MCAST_GRP_NAME] || !tb2[CTRL_ATTR_MCAST_GRP_ID]) { - continue; - } - name = nla_data(tb2[CTRL_ATTR_MCAST_GRP_NAME]); - len = nla_len(tb2[CTRL_ATTR_MCAST_GRP_NAME]); - if (strncmp(name, "scan", len) == 0) { - res->scan_id = nla_get_u32(tb2[CTRL_ATTR_MCAST_GRP_ID]); - } - if (strncmp(name, "mlme", len) == 0) { - res->mlme_id = nla_get_u32(tb2[CTRL_ATTR_MCAST_GRP_ID]); - } - }; - - return NL_SKIP; -} - -static int send_and_recv( - struct nl_sock *sock, - struct nl_msg *msg, - int (*valid_handler)(struct nl_msg *, void *), - void *valid_data) -{ - struct nl_cb *cb; - int err = -ENOMEM; - - cb = nl_cb_alloc(NL_CB_me); - if (!cb) - goto out; - - err = nl_send_auto(sock, msg); - if (err < 0) - goto out; - - err = 1; - - nl_cb_err(cb, NL_CB_CUSTOM, error_handler, &err); - nl_cb_set(cb, NL_CB_FINISH, NL_CB_CUSTOM, finish_handler, &err); - nl_cb_set(cb, NL_CB_ACK, NL_CB_CUSTOM, ack_handler, &err); - - if (valid_handler) { - nl_cb_set(cb, NL_CB_VALID, NL_CB_CUSTOM, valid_handler, valid_data); - } - - while (err > 0) { - int res = nl_recvmsgs(sock, cb); - if (res < 0 && err > 0) { - err = res; - } - } - out: - nl_cb_put(cb); - nlmsg_free(msg); - return err; -} - -static int nl_get_multicast_ids( - struct nl_sock *genl_sock, - struct nl80211_multicast_ids *res) -{ - struct nl_msg *msg; - int ret = -1; - - msg = nlmsg_alloc(); - if (!msg) - return -ENOMEM; - genlmsg_put(msg, 0, 0, genl_ctrl_resolve(genl_sock, "nlctrl"), 0, 0, - CTRL_CMD_GETFAMILY, 0); - NLA_PUT_STRING(msg, CTRL_ATTR_FAMILY_NAME, "nl80211"); - - ret = send_and_recv(genl_sock, msg, family_handler, res); - msg = NULL; - - nla_put_failure: - nlmsg_free(msg); - return ret; -} - -static const char *nl80211_command_to_string(enum nl80211_commands cmd) { -#define C2S(x) \ - case x: \ - return &#x[12] - switch (cmd) { - C2S(NL80211_CMD_UNSPEC); - C2S(NL80211_CMD_GET_WIPHY); - C2S(NL80211_CMD_SET_WIPHY); - C2S(NL80211_CMD_NEW_WIPHY); - C2S(NL80211_CMD_DEL_WIPHY); - C2S(NL80211_CMD_GET_INTERFACE); - C2S(NL80211_CMD_SET_INTERFACE); - C2S(NL80211_CMD_NEW_INTERFACE); - C2S(NL80211_CMD_DEL_INTERFACE); - C2S(NL80211_CMD_GET_KEY); - C2S(NL80211_CMD_SET_KEY); - C2S(NL80211_CMD_NEW_KEY); - C2S(NL80211_CMD_DEL_KEY); - C2S(NL80211_CMD_GET_BEACON); - C2S(NL80211_CMD_SET_BEACON); - C2S(NL80211_CMD_START_AP); - C2S(NL80211_CMD_STOP_AP); - C2S(NL80211_CMD_GET_STATION); - C2S(NL80211_CMD_SET_STATION); - C2S(NL80211_CMD_NEW_STATION); - C2S(NL80211_CMD_DEL_STATION); - C2S(NL80211_CMD_GET_MPATH); - C2S(NL80211_CMD_SET_MPATH); - C2S(NL80211_CMD_NEW_MPATH); - C2S(NL80211_CMD_DEL_MPATH); - C2S(NL80211_CMD_SET_BSS); - C2S(NL80211_CMD_SET_REG); - C2S(NL80211_CMD_REQ_SET_REG); - C2S(NL80211_CMD_GET_MESH_CONFIG); - C2S(NL80211_CMD_SET_MESH_CONFIG); - C2S(NL80211_CMD_SET_MGMT_EXTRA_IE); - C2S(NL80211_CMD_GET_REG); - C2S(NL80211_CMD_GET_SCAN); - C2S(NL80211_CMD_TRIGGER_SCAN); - C2S(NL80211_CMD_NEW_SCAN_RESULTS); - C2S(NL80211_CMD_SCAN_ABORTED); - C2S(NL80211_CMD_REG_CHANGE); - C2S(NL80211_CMD_AUTHENTICATE); - C2S(NL80211_CMD_ASSOCIATE); - C2S(NL80211_CMD_DEAUTHENTICATE); - C2S(NL80211_CMD_DISASSOCIATE); - C2S(NL80211_CMD_MICHAEL_MIC_FAILURE); - C2S(NL80211_CMD_REG_BEACON_HINT); - C2S(NL80211_CMD_JOIN_IBSS); - C2S(NL80211_CMD_LEAVE_IBSS); - C2S(NL80211_CMD_TESTMODE); - C2S(NL80211_CMD_CONNECT); - C2S(NL80211_CMD_ROAM); - C2S(NL80211_CMD_DISCONNECT); - C2S(NL80211_CMD_SET_WIPHY_NETNS); - C2S(NL80211_CMD_GET_SURVEY); - C2S(NL80211_CMD_NEW_SURVEY_RESULTS); - C2S(NL80211_CMD_SET_PMKSA); - C2S(NL80211_CMD_DEL_PMKSA); - C2S(NL80211_CMD_FLUSH_PMKSA); - C2S(NL80211_CMD_REMAIN_ON_CHANNEL); - C2S(NL80211_CMD_CANCEL_REMAIN_ON_CHANNEL); - C2S(NL80211_CMD_SET_TX_BITRATE_MASK); - C2S(NL80211_CMD_REGISTER_FRAME); - C2S(NL80211_CMD_FRAME); - C2S(NL80211_CMD_FRAME_TX_STATUS); - C2S(NL80211_CMD_SET_POWER_SAVE); - C2S(NL80211_CMD_GET_POWER_SAVE); - C2S(NL80211_CMD_SET_CQM); - C2S(NL80211_CMD_NOTIFY_CQM); - C2S(NL80211_CMD_SET_CHANNEL); - C2S(NL80211_CMD_SET_WDS_PEER); - C2S(NL80211_CMD_FRAME_WAIT_CANCEL); - C2S(NL80211_CMD_JOIN_MESH); - C2S(NL80211_CMD_LEAVE_MESH); - C2S(NL80211_CMD_UNPROT_DEAUTHENTICATE); - C2S(NL80211_CMD_UNPROT_DISASSOCIATE); - C2S(NL80211_CMD_NEW_PEER_CANDIDATE); - C2S(NL80211_CMD_GET_WOWLAN); - C2S(NL80211_CMD_SET_WOWLAN); - C2S(NL80211_CMD_START_SCHED_SCAN); - C2S(NL80211_CMD_STOP_SCHED_SCAN); - C2S(NL80211_CMD_SCHED_SCAN_RESULTS); - C2S(NL80211_CMD_SCHED_SCAN_STOPPED); - C2S(NL80211_CMD_SET_REKEY_OFFLOAD); - C2S(NL80211_CMD_PMKSA_CANDIDATE); - C2S(NL80211_CMD_TDLS_OPER); - C2S(NL80211_CMD_TDLS_MGMT); - C2S(NL80211_CMD_UNEXPECTED_FRAME); - C2S(NL80211_CMD_PROBE_CLIENT); - C2S(NL80211_CMD_REGISTER_BEACONS); - C2S(NL80211_CMD_UNEXPECTED_4ADDR_FRAME); - C2S(NL80211_CMD_SET_NOACK_MAP); - C2S(NL80211_CMD_CH_SWITCH_NOTIFY); - C2S(NL80211_CMD_START_P2P_DEVICE); - C2S(NL80211_CMD_STOP_P2P_DEVICE); - C2S(NL80211_CMD_CONN_FAILED); - C2S(NL80211_CMD_SET_MCAST_RATE); - C2S(NL80211_CMD_SET_MAC_ACL); - C2S(NL80211_CMD_RADAR_DETECT); - C2S(NL80211_CMD_GET_PROTOCOL_FEATURES); - C2S(NL80211_CMD_UPDATE_FT_IES); - C2S(NL80211_CMD_FT_EVENT); - C2S(NL80211_CMD_CRIT_PROTOCOL_START); - C2S(NL80211_CMD_CRIT_PROTOCOL_STOP); - C2S(NL80211_CMD_GET_COALESCE); - C2S(NL80211_CMD_SET_COALESCE); - C2S(NL80211_CMD_CHANNEL_SWITCH); - C2S(NL80211_CMD_VENDOR); - C2S(NL80211_CMD_SET_QOS_MAP); - default: - return "NL80211_CMD_UNKNOWN"; - } -#undef C2S -} - -static int observe_wlan_event(struct Listener* listener, int ifindex, const char* cmd, PyObject* extra) -{ - if (listener->exc_typ != NULL || listener->observer == Py_None) { - return NL_STOP; - } - PyObject *arg = PyDict_New(); - PyObject *ob = NULL; - - if (arg == NULL) { - goto exit; - } - - ob = PyUnicode_FromString(cmd); - if (ob == NULL || PyDict_SetItemString(arg, "cmd", ob) < 0) { - goto exit; - } - Py_DECREF(ob); - ob = NULL; - - ob = PyLong_FromLong(ifindex); - if (ob == NULL || PyDict_SetItemString(arg, "ifindex", ob) < 0) { - goto exit; - } - Py_DECREF(ob); - ob = NULL; - - if (extra != NULL) { - PyDict_Update(arg, extra); - } - - PyObject *r = PyObject_CallMethod(listener->observer, "wlan_event", "O", arg); - Py_XDECREF(r); - exit: - Py_XDECREF(arg); - Py_XDECREF(ob); - if (PyErr_Occurred()) { - PyErr_Fetch(&listener->exc_typ, &listener->exc_val, &listener->exc_tb); - return NL_STOP; - } - - return NL_SKIP; -} - -static int nl80211_trigger_scan(struct Listener *listener, int ifidx) { - struct nl_msg *msg = NULL; - struct nl_msg *ssids = NULL; - int r; - - struct nl_sock *genl_sock = nl_socket_alloc(); - if (genl_sock == NULL) { - r = -1; - goto nla_put_failure; - } - r = genl_connect(genl_sock); - if (r < 0) { - goto nla_put_failure; - } - - msg = nlmsg_alloc(); - if (!msg) { - goto nla_put_failure; - } - genlmsg_put(msg, 0, 0, listener->nl80211_id, 0, 0, NL80211_CMD_TRIGGER_SCAN, 0); - NLA_PUT_U32(msg, NL80211_ATTR_IFINDEX, ifidx); - ssids = nlmsg_alloc(); - if (!ssids) { - goto nla_put_failure; - } - NLA_PUT(ssids, 1, 0, ""); - nla_put_nested(msg, NL80211_ATTR_SCAN_SSIDS, ssids); - - r = send_and_recv(genl_sock, msg, NULL, NULL); - msg = NULL; - nla_put_failure: - nlmsg_free(msg); - nlmsg_free(ssids); - nl_socket_free(genl_sock); - return r; -} - -static char *nl80211_get_ie(char *ies, size_t ies_len, char ie) { - /* - * It is important to work with unsigned here because the length field of - * an IE is one byte. If the length is > 0x7F and we're working with signed - * chars, we will interpret it as a negative length, causing various issues - * like infinite loops. - */ - unsigned char *end, *pos; - - if (ies == NULL) - return NULL; - - pos = (unsigned char *)ies; - end = (unsigned char *)ies + ies_len; - - while (pos + 1 < end) { - if (pos + 2 + pos[1] > end) - break; - if (pos[0] == ie) - return (char *)pos; - pos += 2 + pos[1]; - } - - return NULL; -} - -struct scan_handler_params { - PyObject* ssid_list; - int only_connected; -}; - -static void extract_ssid(struct nlattr *data, struct scan_handler_params *p) -{ - struct nlattr *bss[NL80211_BSS_MAX + 1]; - static struct nla_policy bss_policy[NL80211_BSS_MAX + 1] = { - [NL80211_BSS_INFORMATION_ELEMENTS] = {}, - [NL80211_BSS_STATUS] = {.type = NLA_U32}, [NL80211_BSS_BSSID] = {}, - }; - if (nla_parse_nested(bss, NL80211_BSS_MAX, data, bss_policy)) - return; - char *cstatus = "no status"; - if (bss[NL80211_BSS_STATUS]) { - int status = -1; - status = nla_get_u32(bss[NL80211_BSS_STATUS]); - switch (status) { - case NL80211_BSS_STATUS_ASSOCIATED: - cstatus = "Connected"; - break; - case NL80211_BSS_STATUS_AUTHENTICATED: - cstatus = "Authenticated"; - break; - case NL80211_BSS_STATUS_IBSS_JOINED: - cstatus = "Joined"; - break; - } - } else if (p->only_connected) { - return; - } - char *ie = nla_data(bss[NL80211_BSS_INFORMATION_ELEMENTS]); - size_t ie_len = nla_len(bss[NL80211_BSS_INFORMATION_ELEMENTS]); - char *ssid = nl80211_get_ie(ie, ie_len, 0); - - if (ssid == NULL) { - /* - * LP: #2104087 For reasons yet to be determined, the SSID information - * element (aka. IE) can sometimes be completely missing. - * We have speculated that it could be related to hidden SSIDs but - * testing showed that having an SSID information element with size 0 - * is a thing. - */ - return; - } - - ssize_t ssid_len = (ssize_t)ssid[1]; - PyObject* v = Py_BuildValue("(y#s)", ssid + 2, ssid_len, cstatus); - if (v == NULL) { - Py_CLEAR(p->ssid_list); - return; - } - if (PyList_Append(p->ssid_list, v) < 0) { - Py_CLEAR(p->ssid_list); - } - Py_DECREF(v); -} - -static int nl80211_scan_handler(struct nl_msg *msg, void *arg) { - struct scan_handler_params *p = (struct scan_handler_params *)arg; - struct genlmsghdr *gnlh = nlmsg_data(nlmsg_hdr(msg)); - struct nlattr *tb[NL80211_ATTR_MAX + 1]; - int ifidx = -1; - - nla_parse(tb, NL80211_ATTR_MAX, genlmsg_attrdata(gnlh, 0), - genlmsg_attrlen(gnlh, 0), NULL); - - if (tb[NL80211_ATTR_IFINDEX]) { - ifidx = nla_get_u32(tb[NL80211_ATTR_IFINDEX]); - } - - if (ifidx < 0) { - return NL_SKIP; - } - - if (tb[NL80211_ATTR_BSS]) { - extract_ssid(tb[NL80211_ATTR_BSS], p); - } - - if (p->ssid_list == NULL) { - return NL_STOP; - } - - return NL_SKIP; -} - -static PyObject* -dump_scan_results(struct Listener* listener, int ifidx, int only_connected) -{ - struct nl_msg *msg = NULL; - struct scan_handler_params p = { .ssid_list = NULL }; - p.only_connected = only_connected; - struct nl_sock *genl_sock = nl_socket_alloc(); - int r; - - if (genl_sock == NULL) { - PyErr_SetString(PyExc_MemoryError, "nl_socket_alloc failed"); - goto nla_put_failure; - } - r = genl_connect(genl_sock); - if (r < 0) { - PyErr_Format(PyExc_MemoryError, "genl_connect failed %d", r); - goto nla_put_failure; - } - p.ssid_list = PyList_New(0); - if (p.ssid_list == NULL) { - goto nla_put_failure; - } - - msg = nlmsg_alloc(); - if (!msg) { - goto nla_put_failure; - } - genlmsg_put(msg, 0, 0, listener->nl80211_id, 0, NLM_F_DUMP, NL80211_CMD_GET_SCAN, 0); - NLA_PUT_U32(msg, NL80211_ATTR_IFINDEX, ifidx); - - send_and_recv(genl_sock, msg, nl80211_scan_handler, &p); - msg = NULL; - nla_put_failure: - nlmsg_free(msg); - nl_socket_free(genl_sock); - return p.ssid_list; -} - -static int event_handler(struct nl_msg *msg, void *arg) -{ - struct Listener* listener = (struct Listener*)arg; - struct genlmsghdr *gnlh = nlmsg_data(nlmsg_hdr(msg)); - struct nlattr *tb[NL80211_ATTR_MAX + 1]; - int ifidx = -1; - PyObject* extra = NULL; - int r; - - nla_parse(tb, NL80211_ATTR_MAX, genlmsg_attrdata(gnlh, 0), - genlmsg_attrlen(gnlh, 0), NULL); - - if (tb[NL80211_ATTR_IFINDEX]) { - ifidx = nla_get_u32(tb[NL80211_ATTR_IFINDEX]); - } - - if (ifidx > 0) { - if (gnlh->cmd == NL80211_CMD_NEW_SCAN_RESULTS) { - PyObject* ssids = dump_scan_results(listener, ifidx, 0); - if (ssids == NULL) { - return NL_STOP; - } - extra = Py_BuildValue("{sO}", "ssids", ssids); - } - - if (gnlh->cmd == NL80211_CMD_ASSOCIATE || gnlh->cmd == NL80211_CMD_NEW_INTERFACE) { - PyObject* ssids = dump_scan_results(listener, ifidx, 1); - if (ssids == NULL) { - return NL_STOP; - } - extra = Py_BuildValue("{sO}", "ssids", ssids); - } - } - - r = observe_wlan_event(listener, ifidx, nl80211_command_to_string(gnlh->cmd), extra); - Py_XDECREF(extra); - return r; -} - - -static PyObject * -listener_new(PyTypeObject *type, PyObject *args, PyObject *kw) -{ - struct nl_sock *event_sock, *genl_sock; - struct nl_cb *event_cb; - - struct Listener* listener = (struct Listener*)type->tp_alloc(type, 0); - - event_cb = nl_cb_alloc(NL_CB_me); - nl_cb_err(event_cb, NL_CB_CUSTOM, error_handler, &listener->err); - nl_cb_set(event_cb, NL_CB_FINISH, NL_CB_CUSTOM, finish_handler, &listener->err); - nl_cb_set(event_cb, NL_CB_ACK, NL_CB_CUSTOM, ack_handler, &listener->err); - nl_cb_set(event_cb, NL_CB_SEQ_CHECK, NL_CB_CUSTOM, no_seq_check, &listener->err); - nl_cb_set(event_cb, NL_CB_VALID, NL_CB_CUSTOM, event_handler, listener); - - event_sock = nl_socket_alloc_cb(event_cb); - if (event_sock == NULL) { - PyErr_SetString(PyExc_MemoryError, "nl_socket_alloc_cb"); - return NULL; - } - - genl_sock = nl_socket_alloc(); - if (genl_sock == NULL) { - nl_socket_free(event_sock); - PyErr_SetString(PyExc_MemoryError, "nl_socket_alloc"); - return NULL; - } - // XXX is this really needed? - nl_socket_set_cb(genl_sock, nl_cb_alloc(NL_CB_me)); - - listener->event_sock = event_sock; - listener->genl_sock = genl_sock; - - Py_INCREF(Py_None); - listener->observer = Py_None; - - return (PyObject*)listener; -} - -static int -listener_init(PyObject *self, PyObject *args, PyObject *kw) -{ - PyObject* observer; - - char *kwlist[] = {"observer", 0}; - - if (!PyArg_ParseTupleAndKeywords(args, kw, "O:listener", kwlist, &observer)) - return -1; - - struct Listener* listener = (struct Listener*)self; - - Py_CLEAR(listener->observer); - Py_INCREF(observer); - listener->observer = observer; - - return 0; -} - -static PyObject* -maybe_restore(struct Listener* listener) { - if (listener->exc_typ != NULL) { - PyErr_Restore(listener->exc_typ, listener->exc_val, listener->exc_tb); - listener->exc_typ = listener->exc_val = listener->exc_tb = NULL; - return NULL; - } - if (listener->err != 0) { - PyErr_Format(PyExc_RuntimeError, "random netlink error: %d", listener->err); - } - Py_RETURN_NONE; -} - -static PyObject* -listener_start(PyObject *self, PyObject* args) -{ - int r; - struct nl80211_multicast_ids ids; - struct Listener* listener = (struct Listener*)self; - - r = genl_connect(listener->genl_sock); - if (r < 0) { - PyErr_Format(PyExc_RuntimeError, "genl_connect failed: %d", r); - return NULL; - } - listener->nl80211_id = genl_ctrl_resolve(listener->genl_sock, "nl80211"); - r = nl_get_multicast_ids(listener->genl_sock, &ids); - if (r < 0) { - PyErr_Format(PyExc_RuntimeError, "nl_get_multicast_ids failed: %d", r); - return NULL; - } - - r = genl_connect(listener->event_sock); - if (r < 0) { - PyErr_Format(PyExc_RuntimeError, "genl_connect failed: %d", r); - return NULL; - } - r = nl_socket_set_nonblocking(listener->event_sock); - if (r < 0) { - PyErr_Format(PyExc_RuntimeError, "nl_socket_set_nonblocking failed: %d", r); - return NULL; - } - r = nl_socket_add_memberships(listener->event_sock, ids.mlme_id, ids.scan_id, 0); - if (r < 0) { - PyErr_Format(PyExc_RuntimeError, "nl_socket_add_memberships: %d", r); - return NULL; - } - - // Request a dump of all wlan interfaces to get us started. - struct nl_msg *msg; - msg = nlmsg_alloc(); - if (!msg) - return NULL; - genlmsg_put(msg, 0, 0, listener->nl80211_id, 0, NLM_F_DUMP, NL80211_CMD_GET_INTERFACE, - 0); - - send_and_recv(listener->genl_sock, msg, event_handler, listener); - - return maybe_restore(listener); -} - -static PyObject* -listener_fileno(PyObject *self, PyObject* args) -{ - struct Listener* listener = (struct Listener*)self; - return PyLong_FromLong(nl_socket_get_fd(listener->event_sock)); -} - -static PyObject* -listener_data_ready(PyObject *self, PyObject* args) -{ - struct Listener* listener = (struct Listener*)self; - - nl_recvmsgs_default(listener->event_sock); - - return maybe_restore(listener); -} - -static PyObject* -listener_trigger_scan(PyObject *self, PyObject* args, PyObject* kw) -{ - struct Listener* listener = (struct Listener*)self; - long ifindex; - - char *kwlist[] = {"ifindex", 0}; - - if (!PyArg_ParseTupleAndKeywords(args, kw, "i:listener", kwlist, &ifindex)) - return NULL; - - int r = 0; - r = nl80211_trigger_scan(listener, ifindex); - - if (r < 0) { - PyErr_Format(PyExc_RuntimeError, "triggering scan failed %d\n", r); - return NULL; - } - - Py_RETURN_NONE; -} - -static PyMethodDef ListenerMethods[] = { - {"start", listener_start, METH_NOARGS, "XXX."}, - {"fileno", listener_fileno, METH_NOARGS, "XXX."}, - {"data_ready", listener_data_ready, METH_VARARGS, "XXX."}, - {"trigger_scan", (PyCFunction)listener_trigger_scan, METH_VARARGS|METH_KEYWORDS, "XXX."}, - {}, -}; - -static PyTypeObject ListenerType = { - .ob_base = PyVarObject_HEAD_INIT(&PyType_Type, 0) - .tp_name = "_nl80211.listener", - .tp_basicsize = sizeof(struct Listener), - - .tp_dealloc = listener_dealloc, - .tp_new = listener_new, - .tp_init = listener_init, - .tp_traverse = listener_traverse, - - .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, - .tp_methods = ListenerMethods, -}; - -static struct PyModuleDef nl80211_module = { - PyModuleDef_HEAD_INIT, - "_nl80211", -}; - -PyMODINIT_FUNC -PyInit__nl80211(void) -{ - PyObject *m = PyModule_Create(&nl80211_module); - - if (m == NULL) - return NULL; - - if (PyType_Ready(&ListenerType) < 0) - return NULL; - - PyModule_AddObject(m, "listener", (PyObject *)&ListenerType); - - return m; -} diff --git a/probert/_rtnetlinkmodule.c b/probert/_rtnetlinkmodule.c deleted file mode 100644 index fed2d1f..0000000 --- a/probert/_rtnetlinkmodule.c +++ /dev/null @@ -1,513 +0,0 @@ -#define PY_SSIZE_T_CLEAN -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#define NL_CB_me NL_CB_DEFAULT - -static char *act2str(int act) { -#define C2S(x) \ - case x: \ - return &#x[7] - switch (act) { - C2S(NL_ACT_UNSPEC); - C2S(NL_ACT_NEW); - C2S(NL_ACT_DEL); - C2S(NL_ACT_GET); - C2S(NL_ACT_SET); - C2S(NL_ACT_CHANGE); - default: - return "???"; - } -#undef C2S -} - -struct Listener { - PyObject_HEAD - struct nl_cache_mngr *mngr; - struct nl_cache *link_cache; - struct nl_cache *route_cache; - PyObject *observer; - PyObject *exc_typ, *exc_val, *exc_tb; -}; - -struct _clear_routes_arg { - struct Listener *listener; - int ifindex; -}; - -static void observe_route_change( - int act, - struct rtnl_route *route, - struct Listener* listener); - -static void _clear_routes(struct nl_object *ob, void *data) { - struct _clear_routes_arg* arg = (struct _clear_routes_arg*)data; - struct rtnl_route* route = (struct rtnl_route*)ob; - - if (rtnl_route_get_nnexthops(route) > 0) { - // Bit cheaty to ignore multipath but.... - struct rtnl_nexthop* nh = rtnl_route_nexthop_n(route, 0); - if (rtnl_route_nh_get_ifindex(nh) == arg->ifindex) { - observe_route_change(NL_ACT_DEL, route, arg->listener); - nl_cache_remove(ob); - } - } -} - -static void observe_link_change( - int act, - struct rtnl_link *old_link, - struct rtnl_link *link, - struct Listener* listener) -{ - if (listener->exc_typ != NULL || listener->observer == Py_None) { - return; - } - PyObject *data; - - struct _clear_routes_arg clear_routes_arg; - int is_vlan, ifindex; - unsigned int flags; - - if (act == NL_ACT_DEL) { - link = old_link; - } - - is_vlan = rtnl_link_is_vlan(link); - ifindex = rtnl_link_get_ifindex(link); - flags = rtnl_link_get_flags(link); - if (!(flags & IFF_UP)) { - if (old_link && (rtnl_link_get_flags(old_link) & IFF_UP)) { - clear_routes_arg.ifindex = ifindex; - clear_routes_arg.listener = listener; - nl_cache_foreach(listener->route_cache, _clear_routes, &clear_routes_arg); - } - } - - data = Py_BuildValue( - "{si sI sI si sN}", - "ifindex", ifindex, - "flags", flags, - "arptype", rtnl_link_get_arptype(link), - "family", rtnl_link_get_family(link), - "is_vlan", PyBool_FromLong(is_vlan)); - if (data == NULL) { - goto exit; - } - if (rtnl_link_get_name(link) != NULL) { - PyObject *ob = PyBytes_FromString(rtnl_link_get_name(link)); - if (ob == NULL || PyDict_SetItemString(data, "name", ob) < 0) { - Py_XDECREF(ob); - goto exit; - } - Py_DECREF(ob); - } - if (is_vlan) { - PyObject* v; - v = PyLong_FromLong(rtnl_link_vlan_get_id(link)); - if (v == NULL || PyDict_SetItemString(data, "vlan_id", v) < 0) { - Py_XDECREF(v); - goto exit; - } - Py_DECREF(v); - v = PyLong_FromLong(rtnl_link_get_link(link)); - if (v == NULL || PyDict_SetItemString(data, "vlan_link", v) < 0) { - Py_XDECREF(v); - goto exit; - } - Py_DECREF(v); - } - PyObject *r = PyObject_CallMethod(listener->observer, "link_change", "sO", act2str(act), data); - Py_XDECREF(r); - - exit: - Py_XDECREF(data); - if (PyErr_Occurred()) { - PyErr_Fetch(&listener->exc_typ, &listener->exc_val, &listener->exc_tb); - } -} - -static void _cb_link(struct nl_cache *cache, struct nl_object *old, struct nl_object *new, uint64_t diff, int act, - void *data) { - observe_link_change(act, (struct rtnl_link *)old, (struct rtnl_link *)new, (struct Listener*)data); -} - -static void _e_link(struct nl_object *ob, void *data) { - observe_link_change(NL_ACT_NEW, NULL, (struct rtnl_link *)ob, (struct Listener*)data); -} - -static void observe_addr_change( - int act, - struct rtnl_addr *addr, - struct Listener* listener) -{ - if (listener->exc_typ != NULL || listener->observer == Py_None) { - return; - } - PyObject *data; - data = Py_BuildValue( - "{si sI si si}", - "ifindex", rtnl_addr_get_ifindex(addr), - "flags", rtnl_addr_get_flags(addr), - "family", rtnl_addr_get_family(addr), - "scope", rtnl_addr_get_scope(addr)); - if (data == NULL) { - goto exit; - } - struct nl_addr *local = rtnl_addr_get_local(addr); - if (local != NULL) { - char buf[100]; - PyObject *ob = PyBytes_FromString(nl_addr2str(local, buf, 100)); - if (ob == NULL || PyDict_SetItemString(data, "local", ob) < 0) { - Py_XDECREF(ob); - goto exit; - } - Py_DECREF(ob); - } - PyObject *r = PyObject_CallMethod(listener->observer, "addr_change", "sO", act2str(act), data); - Py_XDECREF(r); - - exit: - Py_XDECREF(data); - if (PyErr_Occurred()) { - PyErr_Fetch(&listener->exc_typ, &listener->exc_val, &listener->exc_tb); - } -} - -static void _cb_addr(struct nl_cache *cache, struct nl_object *ob, int act, - void *data) { - observe_addr_change(act, (struct rtnl_addr *)ob, (struct Listener*)data); -} - -static void _e_addr(struct nl_object *ob, void *data) { - observe_addr_change(NL_ACT_NEW, (struct rtnl_addr *)ob, (struct Listener*)data); -} - -static void observe_route_change( - int act, - struct rtnl_route *route, - struct Listener* listener) -{ - if (listener->exc_typ != NULL || listener->observer == Py_None) { - return; - } - PyObject *data; - char *cdst; - char dstbuf[64]; - struct nl_addr* dst = rtnl_route_get_dst(route); - if (dst == NULL || nl_addr_get_len(dst) == 0) { - cdst = "default"; - } else { - cdst = nl_addr2str(dst, dstbuf, sizeof(dstbuf)); - } - - int ifindex = -1; - int nnexthops = rtnl_route_get_nnexthops(route); - if (nnexthops > 0) { - // Bit cheaty to ignore multipath but.... - struct rtnl_nexthop* nh = rtnl_route_nexthop_n(route, 0); - ifindex = rtnl_route_nh_get_ifindex(nh); - } - data = Py_BuildValue( - "{sB sB sI sy si}", - "family", rtnl_route_get_family(route), - "type", rtnl_route_get_type(route), - "table", rtnl_route_get_table(route), - "dst", cdst, - "ifindex", ifindex); - if (data == NULL) { - goto exit; - } - PyObject *r = PyObject_CallMethod(listener->observer, "route_change", "sO", act2str(act), data); - Py_XDECREF(r); - - exit: - Py_XDECREF(data); - if (PyErr_Occurred()) { - PyErr_Fetch(&listener->exc_typ, &listener->exc_val, &listener->exc_tb); - } -} - -static void _cb_route(struct nl_cache *cache, struct nl_object *ob, int act, - void *data) { - observe_route_change(act, (struct rtnl_route *)ob, (struct Listener*)data); -} - -static void _e_route(struct nl_object *ob, void *data) { - observe_route_change(NL_ACT_NEW, (struct rtnl_route *)ob, (struct Listener*)data); -} - -static void -listener_dealloc(PyObject *self) { - struct Listener* v = (struct Listener*)self; - PyObject_GC_UnTrack(v); - Py_CLEAR(v->observer); - nl_cache_mngr_free(v->mngr); - Py_CLEAR(v->exc_typ); - Py_CLEAR(v->exc_val); - Py_CLEAR(v->exc_tb); - PyObject_GC_Del(v); -} - -static int -listener_traverse(PyObject *self, visitproc visit, void *arg) -{ - struct Listener* v = (struct Listener*)self; - Py_VISIT(v->observer); - Py_VISIT(v->exc_typ); - Py_VISIT(v->exc_val); - Py_VISIT(v->exc_tb); - return 0; -} - -static PyTypeObject ListenerType; - -static PyObject * -listener_new(PyTypeObject *type, PyObject *args, PyObject *kw) -{ - struct nl_cache_mngr *mngr; - int r; - - r = nl_cache_mngr_alloc(NULL, NETLINK_ROUTE, NL_AUTO_PROVIDE, &mngr); - if (r < 0) { - PyErr_Format(PyExc_MemoryError, "nl_cache_mngr_alloc failed %d", r); - return NULL; - } - - struct Listener* listener = (struct Listener*)type->tp_alloc(type, 0); - - listener->mngr = mngr; - - Py_INCREF(Py_None); - listener->observer = Py_None; - - return (PyObject*)listener; -} - -static int -listener_init(PyObject *self, PyObject *args, PyObject *kw) -{ - PyObject* observer; - - char *kwlist[] = {"observer", 0}; - - if (!PyArg_ParseTupleAndKeywords(args, kw, "O:listener", kwlist, &observer)) - return -1; - - struct Listener* listener = (struct Listener*)self; - - Py_CLEAR(listener->observer); - Py_INCREF(observer); - listener->observer = observer; - - return 0; -} - -static PyObject* -maybe_restore(struct Listener* listener) { - if (listener->exc_typ != NULL) { - PyErr_Restore(listener->exc_typ, listener->exc_val, listener->exc_tb); - listener->exc_typ = listener->exc_val = listener->exc_tb = NULL; - return NULL; - } - Py_RETURN_NONE; -} - -static PyObject* -listener_start(PyObject *self, PyObject* args) -{ - struct nl_cache *addr_cache; - struct Listener* listener = (struct Listener*)self; - int r; - - r = rtnl_link_alloc_cache(NULL, AF_UNSPEC, &listener->link_cache); - if (r < 0) { - PyErr_Format(PyExc_MemoryError, "rtnl_link_alloc_cache failed %d\n", r); - return NULL; - } - - r = nl_cache_mngr_add_cache_v2(listener->mngr, listener->link_cache, _cb_link, listener); - if (r < 0) { - nl_cache_free(listener->link_cache); - listener->link_cache = NULL; - PyErr_Format(PyExc_RuntimeError, "nl_cache_mngr_add_cache failed %d\n", r); - return NULL; - } - - r = rtnl_addr_alloc_cache(NULL, &addr_cache); - if (r < 0) { - PyErr_Format(PyExc_MemoryError, "rtnl_link_alloc_cache failed %d\n", r); - return NULL; - } - - r = nl_cache_mngr_add_cache(listener->mngr, addr_cache, _cb_addr, listener); - if (r < 0) { - nl_cache_free(addr_cache); - PyErr_Format(PyExc_RuntimeError, "nl_cache_mngr_add_cache failed %d\n", r); - return NULL; - } - - r = rtnl_route_alloc_cache(NULL, AF_UNSPEC, 0, &listener->route_cache); - if (r < 0) { - PyErr_Format(PyExc_MemoryError, "rtnl_route_alloc_cache failed %d\n", r); - return NULL; - } - - r = nl_cache_mngr_add_cache(listener->mngr, listener->route_cache, _cb_route, listener); - if (r < 0) { - nl_cache_free(listener->route_cache); - PyErr_Format(PyExc_RuntimeError, "nl_cache_mngr_add_cache failed %d\n", r); - return NULL; - } - - nl_cache_foreach(listener->link_cache, _e_link, self); - nl_cache_foreach(addr_cache, _e_addr, self); - nl_cache_foreach(listener->route_cache, _e_route, self); - - return maybe_restore(listener); -} - -static PyObject* -listener_fileno(PyObject *self, PyObject* args) -{ - struct Listener* listener = (struct Listener*)self; - return PyLong_FromLong(nl_cache_mngr_get_fd(listener->mngr)); -} - -static PyObject* -listener_data_ready(PyObject *self, PyObject* args) -{ - struct Listener* listener = (struct Listener*)self; - nl_cache_mngr_data_ready(listener->mngr); - return maybe_restore(listener); -} - -static PyObject* -listener_set_link_flags(PyObject *self, PyObject* args, PyObject* kw) -{ - int ifindex, flags; - - char *kwlist[] = {"ifindex", "flags", 0}; - - if (!PyArg_ParseTupleAndKeywords(args, kw, "ii:set_link_flags", kwlist, &ifindex, &flags)) - return NULL; - struct Listener* listener = (struct Listener*)self; - struct rtnl_link *link = rtnl_link_get(listener->link_cache, ifindex); - if (link == NULL) { - PyErr_SetString(PyExc_RuntimeError, "link not found"); - return NULL; - } - struct nl_sock* sk = nl_socket_alloc(); - if (sk == NULL) { - rtnl_link_put(link); - PyErr_SetString(PyExc_MemoryError, "nl_socket_alloc() failed"); - return NULL; - } - int r = nl_connect(sk, NETLINK_ROUTE); - if (r < 0) { - rtnl_link_put(link); - nl_socket_free(sk); - PyErr_Format(PyExc_RuntimeError, "nl_connect failed %d", r); - return NULL; - } - rtnl_link_set_flags(link, flags); - r = rtnl_link_change(sk, link, link, 0); - rtnl_link_put(link); - nl_socket_free(sk); - if (r < 0) { - PyErr_Format(PyExc_RuntimeError, "rtnl_link_change failed %d", r); - return NULL; - } - Py_RETURN_NONE; -} - -static PyObject* -listener_unset_link_flags(PyObject *self, PyObject* args, PyObject* kw) -{ - int ifindex, flags; - - char *kwlist[] = {"ifindex", "flags", 0}; - - if (!PyArg_ParseTupleAndKeywords(args, kw, "ii:unset_link_flags", kwlist, &ifindex, &flags)) - return NULL; - struct Listener* listener = (struct Listener*)self; - struct rtnl_link *link = rtnl_link_get(listener->link_cache, ifindex); - if (link == NULL) { - PyErr_SetString(PyExc_RuntimeError, "link not found"); - return NULL; - } - struct nl_sock* sk = nl_socket_alloc(); - if (sk == NULL) { - rtnl_link_put(link); - PyErr_SetString(PyExc_MemoryError, "nl_socket_alloc() failed"); - return NULL; - } - int r = nl_connect(sk, NETLINK_ROUTE); - if (r < 0) { - rtnl_link_put(link); - nl_socket_free(sk); - PyErr_Format(PyExc_RuntimeError, "nl_connect failed %d", r); - return NULL; - } - rtnl_link_unset_flags(link, flags); - r = rtnl_link_change(sk, link, link, 0); - rtnl_link_put(link); - nl_socket_free(sk); - if (r < 0) { - PyErr_Format(PyExc_RuntimeError, "rtnl_link_change failed %d", r); - return NULL; - } - Py_RETURN_NONE; -} - -static PyMethodDef ListenerMethods[] = { - {"start", listener_start, METH_NOARGS, "XXX."}, - {"fileno", listener_fileno, METH_NOARGS, "XXX."}, - {"data_ready", listener_data_ready, METH_NOARGS, "XXX."}, - {"set_link_flags", (PyCFunction)listener_set_link_flags, METH_VARARGS|METH_KEYWORDS, "XXX."}, - {"unset_link_flags", (PyCFunction)listener_unset_link_flags, METH_VARARGS|METH_KEYWORDS, "XXX."}, - {}, -}; - -static PyTypeObject ListenerType = { - .ob_base = PyVarObject_HEAD_INIT(&PyType_Type, 0) - .tp_name = "_rtnetlink.listener", - .tp_basicsize = sizeof(struct Listener), - - .tp_dealloc = listener_dealloc, - .tp_new = listener_new, - .tp_init = listener_init, - .tp_traverse = listener_traverse, - - .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, - .tp_methods = ListenerMethods, -}; - -static struct PyModuleDef rtnetlink_module = { - PyModuleDef_HEAD_INIT, - "_rtnetlink", -}; - -PyMODINIT_FUNC -PyInit__rtnetlink(void) -{ - PyObject *m = PyModule_Create(&rtnetlink_module); - - if (m == NULL) - return NULL; - - if (PyType_Ready(&ListenerType) < 0) - return NULL; - - PyModule_AddObject(m, "listener", (PyObject *)&ListenerType); - - return m; -} diff --git a/probert/network.py b/probert/network.py index f7857c3..f8a4f44 100644 --- a/probert/network.py +++ b/probert/network.py @@ -24,15 +24,12 @@ import pyudev +import probert.nl80211 +import probert.rtnetlink.listener from probert.utils import udev_get_attributes log = logging.getLogger('probert.network') -try: - from probert import _rtnetlink -except ImportError as e: - log.warning('Failed import _rtnetlink library modules: %s', e) - # Standard interface flags (net/if.h) IFF_UP = 0x1 # Interface is up. IFF_BROADCAST = 0x2 # Broadcast address valid. @@ -635,8 +632,8 @@ class UdevObserver(NetworkObserver): """Use udev/netlink to observe network changes.""" def __init__(self, receiver=None, *, with_wlan_listener: bool = True): - """ Listen to and handle network events using our _rtnetlink Python - extension. Also optionally use our _nl80211 Python extension for + """ Listen to and handle network events using our rtnetlink Python + module. Also optionally use our nl80211 Python module for scanning when with_wlan_listener is True. """ self._links = {} self.context = pyudev.Context() @@ -647,13 +644,12 @@ def __init__(self, receiver=None, *, with_wlan_listener: bool = True): self._calls = None if with_wlan_listener: - from probert import _nl80211 - self.wlan_listener = _nl80211.listener(self) + self.wlan_listener = probert.nl80211.Listener(self) else: self.wlan_listener = None def start(self): - self.rtlistener = _rtnetlink.listener(self) + self.rtlistener = probert.rtnetlink.listener.Listener(self) with CoalescedCalls(self): self.rtlistener.start() diff --git a/probert/nl80211.py b/probert/nl80211.py new file mode 100644 index 0000000..6bc9012 --- /dev/null +++ b/probert/nl80211.py @@ -0,0 +1,166 @@ +# Copyright 2025 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +""" This module is a pyroute2-based rewrite of _nl80211module.c (which was a C +implementation using libnl). + +NOTE: pyroute2 comes with a "pyroute2.iwutil" module (along with the IW class). +Is is marked experimental but could potentially replace some of the code below. +""" + +from typing import Any + +import pyroute2 +from pyroute2.netlink import NLM_F_ACK, NLM_F_DUMP, NLM_F_REQUEST +from pyroute2.netlink.nl80211 import (NL80211_BSS_STATUS_ASSOCIATED, + NL80211_BSS_STATUS_AUTHENTICATED, + NL80211_BSS_STATUS_IBSS_JOINED, + NL80211_NAMES, nl80211cmd) + + +def nl_except_to_runtime_err(txt: str): + """The old nl80211 implementation written in C raised RuntimeError + exceptions. Pyroute2, on the other hand, raises pyroute2 exceptions (which + do not inherit from RuntimeError). Use this decorator on nl80211 function + that previously raised RuntimeErrors - to get a similar behavior.""" + def decorator(func): + def inner(*args, **kwargs): + try: + return func(*args, **kwargs) + except pyroute2.netlink.exceptions.NetlinkError as nle: + raise RuntimeError(f"{txt} -{nle.code}") from nle + return inner + return decorator + + +class Listener: + def __init__(self, observer) -> None: + self.observer = observer + self.nl80211 = pyroute2.netlink.nl80211.NL80211() + + @nl_except_to_runtime_err("starting listener failed") + def start(self) -> None: + self.nl80211.bind() + # The "scan" multicast group provides notifications for "TRIGGER_SCAN" + # and "NEW_SCAN_RESULTS". + self.nl80211.add_membership("scan") + # The "mlme" multicast group provides notifications for + # "ASSOCIATE", "AUTHENTICATE", "CONNECT", "DISCONNECT", + # "DEAUTHENTICATE", ... + self.nl80211.add_membership("mlme") + + # Request a dump of all WLAN interfaces to get us started. + # This will produce "NEW_INTERFACE" events. + msg = nl80211cmd() + msg["cmd"] = NL80211_NAMES["NL80211_CMD_GET_INTERFACE"] + + responses = self.nl80211.nlm_request( + msg, msg_type=self.nl80211.prid, + msg_flags=NLM_F_REQUEST | NLM_F_DUMP + ) + + for response in responses: + self.event_handler(response) + + def fileno(self) -> int: + return self.nl80211.fileno() + + def dump_scan_results( + self, ifindex: int, only_connected: bool + ) -> list[tuple[bytes, str]]: + """Return a list of (ssid, status)""" + + msg = nl80211cmd() + + msg["cmd"] = NL80211_NAMES["NL80211_CMD_GET_SCAN"] + msg["attrs"] = [["NL80211_ATTR_IFINDEX", ifindex]] + + responses = self.nl80211.nlm_request( + msg, msg_type=self.nl80211.prid, + msg_flags=NLM_F_REQUEST | NLM_F_DUMP + ) + + ssids: list[tuple[bytes, str]] = [] + for response in responses: + if (bss := response.get_attr("NL80211_ATTR_BSS")) is None: + continue + + status = "no status" + if (bss_status := bss.get_attr("NL80211_BSS_STATUS")) is not None: + if bss_status == NL80211_BSS_STATUS_ASSOCIATED: + status = "Connected" + elif bss_status == NL80211_BSS_STATUS_AUTHENTICATED: + status = "Authenticated" + elif bss_status == NL80211_BSS_STATUS_IBSS_JOINED: + status = "Joined" + else: + if only_connected: + continue + + if (ssid := bss.get_nested("NL80211_BSS_INFORMATION_ELEMENTS", + "SSID")): + ssids.append((ssid, status)) + + return ssids + + def event_handler(self, event: nl80211cmd) -> None: + """Invoke the wlan_event function from the observer, optionally + including a scan result.""" + ifindex: int | None = event.get_attr("NL80211_ATTR_IFINDEX") + + cmd = None + if "event" in event: + cmd = event["event"] + + # To behave the same as the old _nl80211module, we set ifindex=-1 when + # ifindex is not provided. Going forward though, we should probably + # leave it as None. + # Also, the old implementation passed cmd="NL80211_CMD_UNKNOWN" when + # cmd is unknown, so let's treat the value specially. + arg: dict[str, Any] = { + "cmd": ( + cmd.removeprefix("NL80211_CMD_") + if cmd is not None + else "NL80211_CMD_UNKNOWN" + ), + "ifindex": ifindex if ifindex is not None else -1, + } + + if ifindex is not None: + if cmd == "NL80211_CMD_NEW_SCAN_RESULTS": + arg["ssids"] = self.dump_scan_results( + ifindex=ifindex, only_connected=False + ) + elif cmd in ("NL80211_CMD_ASSOCIATE", "NL80211_CMD_NEW_INTERFACE"): + arg["ssids"] = self.dump_scan_results( + ifindex=ifindex, only_connected=True + ) + + self.observer.wlan_event(arg) + + def data_ready(self) -> None: + for event in self.nl80211.get(): + self.event_handler(event) + + @nl_except_to_runtime_err("triggering scan failed") + def trigger_scan(self, ifindex: int) -> None: + msg = nl80211cmd() + + msg["cmd"] = NL80211_NAMES["NL80211_CMD_TRIGGER_SCAN"] + msg["attrs"] = [["NL80211_ATTR_IFINDEX", ifindex]] + + self.nl80211.nlm_request( + msg, msg_type=self.nl80211.prid, + msg_flags=NLM_F_REQUEST | NLM_F_ACK + ) diff --git a/probert/rtnetlink/__init__.py b/probert/rtnetlink/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/probert/rtnetlink/addr.py b/probert/rtnetlink/addr.py new file mode 100644 index 0000000..d0377a0 --- /dev/null +++ b/probert/rtnetlink/addr.py @@ -0,0 +1,105 @@ +# Copyright 2025 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +""" This module is part of the pyroute2-based rewrite of _rtnetlinkmodule.c +(which was a C implementation using libnl). +""" + +import dataclasses +import ipaddress +import typing + +from pyroute2.netlink import nlmsg + +from probert.rtnetlink.cache import Cache, CacheEntry, CacheEntryComparer + + +def build_event_data(msg: nlmsg) -> dict[str, typing.Any]: + data = { + "ifindex": msg["index"], + # msg["flags"] (i.e., ifaddrmsg.ifa_flags) is a 8-bits integer and + # can only store some of the flags. The IFA_FLAGS attribute is an + # extension that supports 32-bits flags. + # See rtnetlink (7) + "flags": msg.get_attr("IFA_FLAGS", msg["flags"]), + "family": msg["family"], + "scope": msg["scope"], + } + + # * For IPv4, the local address is stored in IFA_LOCAL. + # * For IPv6, the local address is in IFA_ADDRESS and IFA_LOCAL does + # not exist. + # See libnl implementation for details. + local_addr = msg.get_attr("IFA_LOCAL", msg.get_attr("IFA_ADDRESS")) + pfxlen = msg["prefixlen"] + if_local_addr = ipaddress.ip_interface(f"{local_addr}/{pfxlen}") + if if_local_addr.max_prefixlen == pfxlen: + local_addr = if_local_addr.ip.compressed + else: + local_addr = if_local_addr.compressed + # For some reason, probert uses decode("latin-1") so let's comply + # ... + data["local"] = local_addr.encode("latin-1") + + return data + + +class AddrCache(Cache): + @dataclasses.dataclass(frozen=True) + class UniqueIdentifier: + """How to uniquely identify an address. This class is used as the key + in the addr cache. + For more information, see in libnl: + .oo_id_attrs_get = addr_id_attrs_get, + .oo_id_attrs = (ADDR_ATTR_FAMILY | ADDR_ATTR_IFINDEX | + ADDR_ATTR_LOCAL | ADDR_ATTR_PREFIXLEN) + """ + ifindex: int + family: int + prefixlen: int + # In theory we want: local and optionally peer (depending on family) + # But let's just include IFA_ADDRESS, IFA_LOCAL + ifa_local: str | None + ifa_address: str | None + + @classmethod + def from_nl_msg(cls, msg: nlmsg) -> "AddrCache.UniqueIdentifier": + return cls( + ifindex=msg["index"], + family=msg["family"], + prefixlen=msg["prefixlen"], + ifa_address=msg.get_attr("IFA_ADDRESS"), + ifa_local=msg.get_attr("IFA_LOCAL"), + ) + + @staticmethod + def are_entries_equal(a: CacheEntry, b: CacheEntry) -> bool: + fields_to_compare = [ + CacheEntryComparer.direct("index"), + CacheEntryComparer.direct("family"), + CacheEntryComparer.direct("scope"), + CacheEntryComparer.attr("IFA_LABEL"), + # local (and peer) addresses. + CacheEntryComparer.direct("prefixlen"), + CacheEntryComparer.attr("IFA_ADDRESS"), + CacheEntryComparer.attr("IFA_LOCAL"), + CacheEntryComparer.attr("IFA_MULTICAST"), + CacheEntryComparer.attr("IFA_BROADCAST"), + CacheEntryComparer.attr("IFA_ANYCAST"), + CacheEntryComparer.attr("IFA_CACHEINFO"), + # flags (IFA_FLAGS is a 32-bits extension) + CacheEntryComparer.direct("flags"), + CacheEntryComparer.attr("IFA_FLAGS"), + ] + return CacheEntryComparer.are_equal(a, b, fields=fields_to_compare) diff --git a/probert/rtnetlink/cache.py b/probert/rtnetlink/cache.py new file mode 100644 index 0000000..5a3dbe9 --- /dev/null +++ b/probert/rtnetlink/cache.py @@ -0,0 +1,88 @@ +# Copyright 2025 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +""" This module is part of the pyroute2-based rewrite of _rtnetlinkmodule.c +(which was a C implementation using libnl). +""" + +import abc +import collections +import dataclasses +import typing + +from pyroute2.netlink import nlmsg + +# In the cache, we store the whole netlink message. +# But only the relevant fields are checked for "equality". +CacheEntry: typing.TypeAlias = nlmsg + + +class CacheEntryComparer: + """Helpers to compare the content of two entries from the cache.""" + @staticmethod + def direct(name: str): + def inner(msg): + return msg[name] + return inner + + @staticmethod + def attr(name: str): + def inner(msg): + return msg.get_attr(name) + return inner + + @staticmethod + def nested_attr(names: list[str]): + def inner(msg): + v = msg + for name in names: + v = v.get_attr(name) + if v is None: + return None + return v + return inner + + @staticmethod + def attr_foreach_value(name: str, callback): + def inner(msg): + attr = msg.get_attr(name) + if attr is None: + return None + return [callback(item) for item in attr] + + return inner + + @staticmethod + def are_equal( + entry_a: CacheEntry, entry_b: CacheEntry, *, + fields: list[typing.Callable[[CacheEntry], bool]]) -> bool: + result = True + for attr_cb in fields: + if attr_cb(entry_a) != attr_cb(entry_b): + result = False + return result + + +class Cache(collections.UserDict, abc.ABC): + @dataclasses.dataclass(frozen=True) + class UniqueIdentifier(abc.ABC): + @classmethod + @abc.abstractmethod + def from_nl_msg(cls, msg: nlmsg) -> "Cache.UniqueIdentifier": + raise NotImplementedError + + @staticmethod + @abc.abstractmethod + def are_entries_equal(a: CacheEntry, b: CacheEntry) -> bool: + raise NotImplementedError diff --git a/probert/rtnetlink/link.py b/probert/rtnetlink/link.py new file mode 100644 index 0000000..ac8b9dc --- /dev/null +++ b/probert/rtnetlink/link.py @@ -0,0 +1,91 @@ +# Copyright 2025 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +""" This module is part of the pyroute2-based rewrite of _rtnetlinkmodule.c +(which was a C implementation using libnl). +""" + +import dataclasses +import typing + +from pyroute2.netlink import nlmsg + +from probert.rtnetlink.cache import Cache, CacheEntry, CacheEntryComparer + + +def build_event_data(msg: nlmsg) -> dict[str, typing.Any]: + link_info = msg.get_attr("IFLA_LINKINFO") + if link_info: + is_vlan = link_info.get_attr("IFLA_INFO_KIND") == "vlan" + else: + is_vlan = False + data = { + "ifindex": msg["index"], + "flags": msg["flags"], + "arptype": msg["ifi_type"], + # This differs from The previous implementation (using libnl) in that + # we don't override family based on IFLA_LINKINFO -> IFLA_INFO_KIND. + "family": msg["family"], + "is_vlan": is_vlan, + "name": msg.get_attr("IFLA_IFNAME").encode("utf-8"), + } + if data["is_vlan"]: + data["vlan_id"] = link_info.get_attr( + "IFLA_INFO_DATA").get_attr("IFLA_VLAN_ID") + data["vlan_link"] = msg.get_attr("IFLA_LINK") + return data + + +class LinkCache(Cache): + @dataclasses.dataclass(frozen=True) + class UniqueIdentifier: + """How to uniquely identify a link. This class is used as the key in + the link cache. + For more information, see in libnl: + .oo_id_attrs = LINK_ATTR_IFINDEX | LINK_ATTR_FAMILY + """ + ifindex: int + family: int + + @classmethod + def from_nl_msg(cls, msg: nlmsg) -> "LinkCache.UniqueIdentifier": + return cls(ifindex=msg["index"], family=msg["family"]) + + @staticmethod + def are_entries_equal(a: CacheEntry, b: CacheEntry) -> bool: + fields_to_compare = [ + CacheEntryComparer.direct("index"), + CacheEntryComparer.attr("IFLA_MTU"), + CacheEntryComparer.attr("IFLA_LINK"), + CacheEntryComparer.attr("IFLA_LINK_NETNSID"), + CacheEntryComparer.attr("IFLA_TXQLEN"), + CacheEntryComparer.attr("IFLA_WEIGHT"), + CacheEntryComparer.attr("IFLA_MASTER"), + CacheEntryComparer.direct("family"), + CacheEntryComparer.attr("IFLA_LINKMODE"), + CacheEntryComparer.attr("IFLA_QDISC"), + CacheEntryComparer.attr("IFLA_IFNAME"), + CacheEntryComparer.attr("IFLA_ADDRESS"), + CacheEntryComparer.attr("IFLA_BROADCAST"), + CacheEntryComparer.attr("IFLA_IFALIAS"), + CacheEntryComparer.attr("IFLA_NUM_VF"), + CacheEntryComparer.attr("IFLA_PROMISCUITY"), + CacheEntryComparer.attr("IFLA_NUM_TX_QUEUES"), + CacheEntryComparer.attr("IFLA_NUM_RX_QUEUES"), + CacheEntryComparer.direct("flags"), + # NOTE: For completeness, we should also look at protoinfo and + # infodata. But these are implementation specific so let's ignore + # them for now. + ] + return CacheEntryComparer.are_equal(a, b, fields=fields_to_compare) diff --git a/probert/rtnetlink/listener.py b/probert/rtnetlink/listener.py new file mode 100644 index 0000000..f6ca4af --- /dev/null +++ b/probert/rtnetlink/listener.py @@ -0,0 +1,178 @@ +# Copyright 2025 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +""" This module is part of the pyroute2-based rewrite of _rtnetlinkmodule.c +(which was a C implementation using libnl). +""" + +import dataclasses +import enum +import typing + +import pyroute2 +from pyroute2.netlink import nlmsg +from pyroute2.netlink.rtnl.ifinfmsg import IFF_UP + +import probert.rtnetlink.addr +import probert.rtnetlink.link +import probert.rtnetlink.route +from probert.rtnetlink.cache import Cache, CacheEntry + + +class EventResult(enum.Enum): + """Enumerates the different outcomes that an event can produce.""" + NEW = "NEW" # Send a NEW event to the observer + CHANGE = "CHANGE" # Send a CHANGE event to the observer + DEL = "DEL" # Send a DEL event to the observer + + DISCARD = "DISCARD" # Do not send any event to the observer + + +class Listener: + @dataclasses.dataclass + class MsgHandler: + new: str + cache: Cache + observer_callback: typing.Callable[[str, dict[str, typing.Any]], None] + change_callback: typing.Callable[[CacheEntry, CacheEntry], None] | None + build_event_data: typing.Callable[[nlmsg], dict[str, typing.Any]] + + def cache_handle_nl_msg(self, msg: nlmsg) -> EventResult: + identifier = self.cache.UniqueIdentifier.from_nl_msg(msg) + if msg["event"] == self.new: + if identifier not in self.cache: + self.cache[identifier] = msg + return EventResult.NEW + + if self.cache.are_entries_equal(self.cache[identifier], msg): + # We still update the cache. Values are not necessarily + # meaningful but they are more up to date. + self.cache[identifier] = msg + return EventResult.DISCARD + + if self.change_callback is not None: + self.change_callback(self.cache[identifier], msg) + + self.cache[identifier] = msg + return EventResult.CHANGE + else: + self.cache.pop(identifier, None) + return EventResult.DEL + + def on_link_change(self, old_link: CacheEntry, + new_link: CacheEntry) -> None: + # When an interface goes down, the kernel does not send RTM_DELROUTE + # message for all routes involving the interface. + # We still need to notify the observer that such routes are no longer + # accessible. + # See https://github.com/thom311/libnl/issues/340 + if new_link["flags"] & IFF_UP or not old_link["flags"] & IFF_UP: + return + + ifindex = new_link["index"] + + # Collect the routes to remove first, so we don't invalidate iterators. + routes_to_del = [] + for route_idx, route in self.route_cache.items(): + if probert.rtnetlink.route.get_ifindex(route) == ifindex: + routes_to_del.append(route_idx) + + for route_to_del in routes_to_del: + route = self.route_cache.pop(route_to_del) + self.observer.route_change( + "DEL", probert.rtnetlink.route.build_event_data(route)) + + def __init__(self, observer) -> None: + self.observer = observer + + # By default, the groups (aka. membership groups) is RTMGRP_DEFAULT, + # which includes neighbours, traffic control, MPLS, rules, etc. We + # don't want to receive notifications for those. + groups = ( + pyroute2.netlink.rtnl.RTMGRP_LINK + | pyroute2.netlink.rtnl.RTMGRP_IPV4_IFADDR + | pyroute2.netlink.rtnl.RTMGRP_IPV6_IFADDR + | pyroute2.netlink.rtnl.RTMGRP_IPV4_ROUTE + | pyroute2.netlink.rtnl.RTMGRP_IPV6_ROUTE + ) + + self.ipr = pyroute2.IPRoute(groups=groups) + + # The caches allow us to discard repetitive NEW events or to emit + # CHANGE events when appropriate. + self.link_cache = probert.rtnetlink.link.LinkCache() + self.addr_cache = probert.rtnetlink.addr.AddrCache() + self.route_cache = probert.rtnetlink.route.RouteCache() + + self.msg_handlers = { + "RTM_NEWLINK": self.MsgHandler( + new="RTM_NEWLINK", + cache=self.link_cache, + observer_callback=self.observer.link_change, + build_event_data=probert.rtnetlink.link.build_event_data, + change_callback=self.on_link_change, + ), "RTM_NEWADDR": self.MsgHandler( + new="RTM_NEWADDR", + cache=self.addr_cache, + observer_callback=self.observer.addr_change, + build_event_data=probert.rtnetlink.addr.build_event_data, + change_callback=None, + ), "RTM_NEWROUTE": self.MsgHandler( + new="RTM_NEWROUTE", + cache=self.route_cache, + observer_callback=self.observer.route_change, + build_event_data=probert.rtnetlink.route.build_event_data, + change_callback=None, + ), + } + self.msg_handlers["RTM_DELLINK"] = self.msg_handlers["RTM_NEWLINK"] + self.msg_handlers["RTM_DELADDR"] = self.msg_handlers["RTM_NEWADDR"] + self.msg_handlers["RTM_DELROUTE"] = self.msg_handlers["RTM_NEWROUTE"] + + def start(self) -> None: + # By default IPRoute adds membership for RTMGRP_LINK + self.ipr.bind() + + for msg in self.ipr.get_links(): + self.handle_nl_msg(msg, emit_change=False) + for msg in self.ipr.get_addr(): + self.handle_nl_msg(msg, emit_change=False) + for msg in self.ipr.get_routes(): + self.handle_nl_msg(msg, emit_change=False) + + def fileno(self) -> int: + return self.ipr.fileno() + + def handle_nl_msg(self, msg: nlmsg, *, emit_change=True) -> None: + handler = self.msg_handlers[msg["event"]] + result = handler.cache_handle_nl_msg(msg) + + if result == EventResult.DISCARD: + return + + # Useful when populating the cache the first time. + if result == EventResult.CHANGE and not emit_change: + return + + handler.observer_callback(result.value, handler.build_event_data(msg)) + + def data_ready(self) -> None: + for msg in self.ipr.get(): + self.handle_nl_msg(msg) + + def set_link_flags(self, ifindex: int, flags: int) -> None: + self.ipr.link("set", index=ifindex, flags=flags, mask=flags) + + def unset_link_flags(self, ifindex: int, flags: int) -> None: + self.ipr.link('set', index=ifindex, flags=0x0, mask=flags) diff --git a/probert/rtnetlink/route.py b/probert/rtnetlink/route.py new file mode 100644 index 0000000..609485b --- /dev/null +++ b/probert/rtnetlink/route.py @@ -0,0 +1,143 @@ + +# Copyright 2025 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +""" This module is a pyroute2-based rewrite of _rtnetlinkmodule.c (which was a +C implementation using libnl). +""" + +import dataclasses +import ipaddress +import typing + +from pyroute2.netlink import nlmsg + +from probert.rtnetlink.cache import Cache, CacheEntry, CacheEntryComparer + + +def get_ifindex(msg: nlmsg) -> int | None: + multipath = msg.get_attr("RTA_MULTIPATH") + if multipath is not None: + # A bit cheaty to ignore multipath but ... + return multipath[0]["oif"] + else: + return msg.get_attr("RTA_OIF") + return None + + +def build_event_data(msg: nlmsg) -> dict[str, typing.Any]: + if not msg["dst_len"]: + dst = "default" + else: + addr = msg.get_attr("RTA_DST") + pfxlen = msg["dst_len"] + network = ipaddress.ip_network(f"{addr}/{pfxlen}") + if network.max_prefixlen == pfxlen: + dst = network.network_address.compressed + else: + dst = network.compressed + + ifindex = get_ifindex(msg) + + return { + "family": msg["family"], + "type": msg["type"], + "table": msg["table"], + "dst": dst.encode("utf-8"), + "ifindex": ifindex if ifindex is not None else -1, + } + + +class RouteCache(Cache): + @dataclasses.dataclass(frozen=True) + class UniqueIdentifier: + """How to uniquely identify a route. This class is used as the key in + the route cache. + For more information, see in libnl: + .oo_id_attrs = (ROUTE_ATTR_FAMILY | ROUTE_ATTR_TOS | + ROUTE_ATTR_TABLE | ROUTE_ATTR_DST | + ROUTE_ATTR_PRIO), + .oo_id_attrs_get = route_id_attrs_get + """ + family: int + tos: int + table: int + dst: str | None + prio: int | None # None for MPLS + # NOTE: Multiple special routes (e.g. multicast routes) can have the + # same destination address but a different output interface (i.e., + # RTA_OIF). They should probably not be considered the same route (and + # therefore RTA_OIF should probably be part of the unique identifier). + # But our previous implementation based on libnl didn't have that + # today so we're mimicking the behavior. + # As a result, in the example below, the second route might potentially + # be discarded since the two routes have the same unique identifier. + # $ ip -6 route show table 255 + # multicast ff00::/8 dev lxdbr0 proto kernel metric 256 pref medium + # multicast ff00::/8 dev dummy2 proto kernel metric 256 pref medium + + @classmethod + def from_nl_msg(cls, msg: nlmsg) -> "RouteCache.UniqueIdentifier": + return cls( + family=msg["family"], + tos=msg["tos"], + table=msg["table"], + dst=msg.get_attr("RTA_DST"), + prio=msg.get_attr("RTA_PRIORITY"), + ) + + @staticmethod + def are_entries_equal(a: CacheEntry, b: CacheEntry) -> bool: + def nexthop_multipath(item) -> list[typing.Any]: + return [ + item["oif"], + item["hops"], + item.get_attr("RTA_GATEWAY"), + item.get_attr("RTA_FLOW"), + item.get_attr("RTA_NEWDST"), + item.get_attr("RTA_VIA"), + ] + + fields_to_compare = [ + CacheEntryComparer.direct("family"), + CacheEntryComparer.direct("tos"), + CacheEntryComparer.direct("table"), + CacheEntryComparer.direct("proto"), + CacheEntryComparer.direct("scope"), + CacheEntryComparer.direct("type"), + CacheEntryComparer.attr("RTA_PRIORITY"), + CacheEntryComparer.attr("RTA_DST"), + CacheEntryComparer.attr("RTA_SRC"), + CacheEntryComparer.attr("RTA_IIF"), + CacheEntryComparer.attr("RTA_PREFSRC"), + CacheEntryComparer.attr("RTA_TTL_PROPAGATE"), + CacheEntryComparer.attr("RTA_METRICS"), + CacheEntryComparer.direct("flags"), + + # Nexthop without multipath + CacheEntryComparer.attr("RTA_OIF"), + CacheEntryComparer.attr("RTA_GATEWAY"), + CacheEntryComparer.attr("RTA_FLOW"), + CacheEntryComparer.attr("RTA_NEWDST"), + CacheEntryComparer.attr("RTA_VIA"), + + # Nexthops with Multipath + CacheEntryComparer.attr_foreach_value("RTA_MULTIPATH", + nexthop_multipath) + + # NOTE: For completeness, we should also dig into the RTA_ENCAP + # nested attribute but this contains implementation specific + # attributes that are unlikely relevant for us. + ] + return CacheEntryComparer.are_equal(a, b, fields=fields_to_compare) diff --git a/probert/tests/helpers.py b/probert/tests/helpers.py index 705de51..56dd0bd 100644 --- a/probert/tests/helpers.py +++ b/probert/tests/helpers.py @@ -14,34 +14,17 @@ # along with this program. If not, see . import contextlib -import imp -import importlib import random import string import unittest -def builtin_module_name(): - options = ('builtins', '__builtin__') - for name in options: - try: - imp.find_module(name) - except ImportError: - continue - else: - print('importing and returning: %s' % name) - importlib.import_module(name) - return name - - @contextlib.contextmanager def simple_mocked_open(content=None): if not content: content = '' m_open = unittest.mock.mock_open(read_data=content) - mod_name = builtin_module_name() - m_patch = '{}.open'.format(mod_name) - with unittest.mock.patch(m_patch, m_open, create=True): + with unittest.mock.patch('builtins.open', m_open, create=True): yield m_open diff --git a/probert/tests/rtnetlink/__init__.py b/probert/tests/rtnetlink/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/probert/tests/rtnetlink/common.py b/probert/tests/rtnetlink/common.py new file mode 100644 index 0000000..fccdcae --- /dev/null +++ b/probert/tests/rtnetlink/common.py @@ -0,0 +1,42 @@ +# Copyright 2025 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import typing +from collections import UserDict + +"""This module provides a helpers to make pyroute2's ipmock classes +behave like nlmsg when it comes to accessing data.""" + + +def get_attr(data: dict[str, typing.Any], + name: str, default=None) -> typing.Any: + for attr_name, attr_val in data["attrs"]: + if attr_name == name: + if isinstance(attr_val, dict) and "attrs" in attr_val: + return AttrList(attr_val) + return attr_val + return default + + +class AttrList(UserDict): + def get_attr(self, *args, **kwargs) -> typing.Any: + return get_attr(self.data, *args, **kwargs) + + +class WithGetAttrMixin: + def __getitem__(self, name: str) -> typing.Any: + return self.export()[name] + + def get_attr(self, *args, **kwargs) -> typing.Any: + return get_attr(self.export(), *args, **kwargs) diff --git a/probert/tests/rtnetlink/test_addr.py b/probert/tests/rtnetlink/test_addr.py new file mode 100644 index 0000000..183f4cb --- /dev/null +++ b/probert/tests/rtnetlink/test_addr.py @@ -0,0 +1,143 @@ +# Copyright 2025 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import socket +import unittest + +from pyroute2.iproute.ipmock import MockAddress + +from probert.rtnetlink.addr import AddrCache, build_event_data +from probert.tests.rtnetlink.common import WithGetAttrMixin + + +class MyMockAddr(WithGetAttrMixin, MockAddress): + """Subclass of pyroute2's MockAddr that makes it behave like a nlmsg""" + + +class TestAddrBuildEventData(unittest.TestCase): + def test_inet(self): + msg = MyMockAddr(index=1, local="192.168.1.1", address="192.168.1.1", + prefixlen=24, family=socket.AF_INET.value) + + self.assertEqual({ + "ifindex": 1, + "flags": 512, # hardcoded in MockAddress + "family": socket.AF_INET.value, + "scope": 0, # hardcoded in MockAddress + "local": b"192.168.1.1/24", + }, build_event_data(msg)) + + def test_inet6(self): + msg = MyMockAddr(index=1, local=None, address="abcd::10", + prefixlen=64, family=socket.AF_INET6.value) + + self.assertEqual({ + "ifindex": 1, + "flags": 512, # hardcoded in MockAddress + "family": socket.AF_INET6.value, + "scope": 0, # hardcoded in MockAddress + "local": b"abcd::10/64", + }, build_event_data(msg)) + + def test_inet_point_to_point(self): + # There is possible confusion between IFA_ADDRESS and IFA_LOCAL, make + # sure we use the right value. The other would be the peer address. + msg = MyMockAddr(index=1, address="192.168.1.2", local="192.168.1.1", + prefixlen=31, family=socket.AF_INET6.value) + + self.assertEqual({ + "ifindex": 1, + "flags": 512, # hardcoded in MockAddress + "family": socket.AF_INET6.value, + "scope": 0, # hardcoded in MockAddress + "local": b"192.168.1.1/31", + }, build_event_data(msg)) + + +class TestAddrCache(unittest.TestCase): + def test_unique_identifier_from_nl_msg__inet(self): + self.assertEqual( + AddrCache.UniqueIdentifier( + ifindex=3, family=socket.AF_INET.value, prefixlen=24, + ifa_local="192.168.1.1", ifa_address="192.168.1.2"), + AddrCache.UniqueIdentifier.from_nl_msg(MyMockAddr( + index=3, family=socket.AF_INET.value, prefixlen=24, + local="192.168.1.1", address="192.168.1.2")), + ) + + def test_unique_identifier_from_nl_msg__inet6(self): + self.assertEqual( + AddrCache.UniqueIdentifier( + ifindex=4, family=socket.AF_INET6.value, prefixlen=72, + ifa_address=None, ifa_local="aaaa::1"), + AddrCache.UniqueIdentifier.from_nl_msg(MyMockAddr( + index=4, family=socket.AF_INET6.value, prefixlen=72, + address=None, local="aaaa::1")) + ) + + def test_are_entries_equal__equal(self): + # Identical addresses are considered equal + self.assertTrue(AddrCache.are_entries_equal( + MyMockAddr(index=1, family=socket.AF_INET.value, prefixlen=16, + local="10.8.1.1"), + MyMockAddr(index=1, family=socket.AF_INET.value, prefixlen=16, + local="10.8.1.1"), + )) + self.assertTrue(AddrCache.are_entries_equal( + MyMockAddr(index=2, family=socket.AF_INET6.value, prefixlen=127, + local=None, address="abcd::1"), + MyMockAddr(index=2, family=socket.AF_INET6.value, prefixlen=127, + local=None, address="abcd::1"), + )) + + def test_are_entries_equal__differ(self): + # addresses differ (these will be two separate cache entries though) + self.assertFalse(AddrCache.are_entries_equal( + MyMockAddr(index=3, family=socket.AF_INET.value, prefixlen=24, + local="10.8.1.1"), + MyMockAddr(index=3, family=socket.AF_INET.value, prefixlen=24, + local="10.8.1.2"), + )) + + # prefixes differ (these will be two separate cache entries though) + self.assertFalse(AddrCache.are_entries_equal( + MyMockAddr(index=3, family=socket.AF_INET6.value, prefixlen=64, + address="aaaa::1"), + MyMockAddr(index=3, family=socket.AF_INET6.value, prefixlen=72, + address="aaaa::1"), + )) + + # ifindexes differ + self.assertFalse(AddrCache.are_entries_equal( + MyMockAddr(index=3, family=socket.AF_INET.value, prefixlen=24, + local="192.168.0.10"), + MyMockAddr(index=4, family=socket.AF_INET.value, prefixlen=24, + local="192.168.0.10"), + )) + + # broadcast addresses differ + self.assertFalse(AddrCache.are_entries_equal( + MyMockAddr(index=4, family=socket.AF_INET.value, prefixlen=24, + local="192.168.0.10", broadcast="192.168.0.255"), + MyMockAddr(index=4, family=socket.AF_INET.value, prefixlen=24, + local="192.168.0.10", broadcast="192.168.0.250"), + )) + + # labels differ + self.assertFalse(AddrCache.are_entries_equal( + MyMockAddr(index=5, family=socket.AF_INET.value, prefixlen=24, + local="192.168.0.10", label="mylabel"), + MyMockAddr(index=5, family=socket.AF_INET.value, prefixlen=24, + local="192.168.0.10"), + )) diff --git a/probert/tests/rtnetlink/test_link.py b/probert/tests/rtnetlink/test_link.py new file mode 100644 index 0000000..653eaee --- /dev/null +++ b/probert/tests/rtnetlink/test_link.py @@ -0,0 +1,102 @@ +# Copyright 2025 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import socket +import unittest + +from pyroute2.iproute.ipmock import MockLink + +from probert.rtnetlink.link import LinkCache, build_event_data +from probert.tests.rtnetlink.common import WithGetAttrMixin + + +class MyMockLink(WithGetAttrMixin, MockLink): + """Subclass of pyroute2's MockLink that makes it behave like a nlmsg""" + + +class TestLinkBuildEventData(unittest.TestCase): + def test_ethernet(self): + msg = MyMockLink(index=1, flags=0x0, ifname="eth0") + + self.assertEqual({ + "ifindex": 1, + "flags": 0x0, + "arptype": 772, # hardcoded by MockLink + "family": socket.AF_UNSPEC.value, # hardcoded by MockLink + "is_vlan": False, + "name": b"eth0", + }, build_event_data(msg)) + + def test_vlan(self): + msg = MyMockLink(index=334, flags=0x0, ifname="vlan20@eth0", + link="eth0", kind="vlan", vlan_id=20) + + self.assertEqual( + {"ifindex": 334, + "flags": 0x0, + "arptype": 772, # hardcoded by MockLink + "family": socket.AF_UNSPEC.value, # hardcoded by MockLink + "is_vlan": True, + "vlan_link": "eth0", + "vlan_id": 20, + "name": b"vlan20@eth0"}, build_event_data(msg)) + + +class TestLinkCache(unittest.TestCase): + def test_unique_identifier_from_nl_msg(self): + self.assertEqual( + LinkCache.UniqueIdentifier( + ifindex=3, family=socket.AF_UNSPEC.value), + LinkCache.UniqueIdentifier.from_nl_msg(MyMockLink(index=3))) + + def test_are_entries_equal__equal(self): + # Completely identical links are considered equal + self.assertTrue(LinkCache.are_entries_equal( + MyMockLink(index=1, flags=0x0, ifname="eth0"), + MyMockLink(index=1, flags=0x0, ifname="eth0"), + )) + + # Links where only stats differ are considered equal + self.assertTrue(LinkCache.are_entries_equal( + MyMockLink(index=1, flags=0x0, ifname="eth0", + rx_packets=100, tx_packets=100), + MyMockLink(index=1, flags=0x0, ifname="eth0"), + )) + self.assertTrue(LinkCache.are_entries_equal( + MyMockLink(index=1, flags=0x0, ifname="eth0", + rx_packets=1000, tx_packets=1000), + MyMockLink(index=1, flags=0x0, ifname="eth0", + rx_packets=2000, tx_packets=2000), + )) + + def test_are_entries_equal__differ(self): + # flags differ + self.assertFalse(LinkCache.are_entries_equal( + MyMockLink(index=1, flags=0x0, ifname="eth0"), + MyMockLink(index=1, flags=0x1, ifname="eth0"), + )) + + # addresses differ + self.assertFalse(LinkCache.are_entries_equal( + MyMockLink(index=1, flags=0x0, ifname="eth0", + address="11:11:11:11:11:11"), + MyMockLink(index=1, flags=0x0, ifname="eth0", + address="22:22:22:22:22:22"), + )) + + # interfaces differ (these will be separate cache entries though) + self.assertFalse(LinkCache.are_entries_equal( + MyMockLink(index=1, flags=0x0, ifname="eth0"), + MyMockLink(index=2, flags=0x0, ifname="eth1"), + )) diff --git a/probert/tests/rtnetlink/test_listener.py b/probert/tests/rtnetlink/test_listener.py new file mode 100644 index 0000000..bf5969a --- /dev/null +++ b/probert/tests/rtnetlink/test_listener.py @@ -0,0 +1,219 @@ +# Copyright 2025 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import socket +import unittest +from unittest import mock + +from pyroute2 import IPRoute +from pyroute2.netlink.rtnl.ifinfmsg import IFF_UP + +from probert.rtnetlink.link import LinkCache +from probert.rtnetlink.listener import EventResult, Listener +from probert.rtnetlink.route import RouteCache +from probert.rtnetlink.route import build_event_data as route_build_event_data +from probert.tests.rtnetlink.test_link import MyMockLink +from probert.tests.rtnetlink.test_route import MyMockRoute + + +class TestListener(unittest.TestCase): + def setUp(self): + self.listener = Listener(mock.Mock()) + + def test_msg_handler_cache_handle_nl_msg__new_not_in_cache(self): + handler = self.listener.msg_handlers["RTM_NEWLINK"] + link = MyMockLink() + + self.assertEqual(EventResult.NEW, + handler.cache_handle_nl_msg(link)) + + def test_msg_handler_cache_handle_nl_msg__new_in_cache_not_updated(self): + handler = self.listener.msg_handlers["RTM_NEWLINK"] + link = MyMockLink() + identifier = LinkCache.UniqueIdentifier.from_nl_msg(link) + + self.listener.link_cache[identifier] = link + + self.assertEqual(EventResult.DISCARD, + handler.cache_handle_nl_msg(link)) + + def test_msg_handler_cache_handle_nl_msg__new_in_cache_updated(self): + handler = self.listener.msg_handlers["RTM_NEWLINK"] + l1 = MyMockLink(index=1, address="aa:aa:aa:aa:aa:aa") + l2 = MyMockLink(index=1, address="bb:bb:bb:bb:bb:bb") + identifier = LinkCache.UniqueIdentifier.from_nl_msg(l1) + + self.listener.link_cache[identifier] = l1 + + self.assertEqual(EventResult.CHANGE, handler.cache_handle_nl_msg(l2)) + + def test_msg_handler_cache_handle_nl_msg__del(self): + handler = self.listener.msg_handlers["RTM_DELLINK"] + link = MyMockLink() + + export = link.export() + export["event"] = "RTM_DELLINK" + + with mock.patch.object(link, "export", return_value=export): + self.assertEqual(EventResult.DEL, + handler.cache_handle_nl_msg(link)) + + def test_on_link_change__no_change_state(self): + with mock.patch.object(self.listener.route_cache, "items") as m_items: + self.listener.on_link_change(MyMockLink(index=1), + MyMockLink(index=1)) + + m_items.assert_not_called() + + def test_on_link_change__change_state(self): + old_link = MyMockLink(index=41, flags=IFF_UP) + new_link = MyMockLink(index=41, flags=0x0) + + routes = [ + MyMockRoute(dst="192.168.1.0", dst_len=24, + family=socket.AF_INET.value, oif=41), + MyMockRoute(dst="192.168.2.0", dst_len=24, + family=socket.AF_INET.value, oif=42), + MyMockRoute(dst="aaaa::", dst_len=64, + family=socket.AF_INET6.value, oif=41), + ] + for route in routes: + identifier = RouteCache.UniqueIdentifier.from_nl_msg(route) + self.listener.route_cache[identifier] = route + + self.listener.on_link_change(old_link, new_link) + self.assertEqual( + [ + mock.call("DEL", route_build_event_data(routes[0])), + mock.call("DEL", route_build_event_data(routes[2])), + ], self.listener.observer.route_change.mock_calls, + ) + + def test_start(self): + p_bind = mock.patch.object(self.listener.ipr, "bind") + p_links = mock.patch.object(self.listener.ipr, "get_links", + return_value=["l1", "l2"]) + p_addr = mock.patch.object(self.listener.ipr, "get_addr", + return_value=["a1", "a2"]) + p_routes = mock.patch.object(self.listener.ipr, "get_routes", + return_value=["r1", "r2"]) + p_handle = mock.patch.object(self.listener, "handle_nl_msg") + + with p_bind as m_bind, p_links as m_links, p_addr as m_addr, \ + p_routes as m_routes, p_handle as m_handle: + self.listener.start() + + m_bind.assert_called_once_with() + m_links.assert_called_once_with() + m_addr.assert_called_once_with() + m_routes.assert_called_once_with() + + self.assertEqual( + [ + mock.call("l1", emit_change=False), + mock.call("l2", emit_change=False), + mock.call("a1", emit_change=False), + mock.call("a2", emit_change=False), + mock.call("r1", emit_change=False), + mock.call("r2", emit_change=False), + ], m_handle.mock_calls) + + def handle_nl_msg(self, type_: str, result: EventResult, emit_change=True): + # msg should be a nlmsg, but dict is okay as long as we don't call + # cache_handle_nl_msg or build_event_data + msg = {"event": type_} + + msg_handler = self.listener.msg_handlers[type_] + + p_cache_handle = mock.patch.object(msg_handler, "cache_handle_nl_msg") + p_build = mock.patch.object(msg_handler, "build_event_data") + + with p_cache_handle as m_cache_handle, p_build as m_build: + m_cache_handle.return_value = result + + self.listener.handle_nl_msg(msg, emit_change=emit_change) + + return msg, msg_handler, m_cache_handle, m_build + + def test_handle_nl_msg__new(self): + res = self.handle_nl_msg("RTM_NEWLINK", result=EventResult.NEW) + msg, handler, m_cache_handle, m_build = res + + m_cache_handle.assert_called_once_with(msg) + m_build.assert_called_once_with(msg) + handler.observer_callback.assert_called_once_with( + "NEW", m_build()) + + def test_handle_nl_msg__delete(self): + res = self.handle_nl_msg("RTM_DELLINK", result=EventResult.DEL) + msg, handler, m_cache_handle, m_build = res + + m_cache_handle.assert_called_once_with(msg) + m_build.assert_called_once_with(msg) + handler.observer_callback.assert_called_once_with( + "DEL", m_build()) + + def test_handle_nl_msg__change(self): + res = self.handle_nl_msg("RTM_NEWLINK", result=EventResult.CHANGE) + msg, handler, m_cache_handle, m_build = res + + m_cache_handle.assert_called_once_with(msg) + m_build.assert_called_once_with(msg) + handler.observer_callback.assert_called_once_with( + "CHANGE", m_build()) + + def test_handle_nl_msg__change_no_emit(self): + res = self.handle_nl_msg("RTM_NEWLINK", result=EventResult.CHANGE, + emit_change=False) + msg, handler, m_cache_handle, m_build = res + + m_cache_handle.assert_called_once_with(msg) + m_build.assert_not_called() + handler.observer_callback.assert_not_called() + + def test_handle_nl_msg__discard(self): + res = self.handle_nl_msg("RTM_NEWLINK", result=EventResult.DISCARD) + msg, handler, m_cache_handle, m_build = res + + m_cache_handle.assert_called_once_with(msg) + m_build.assert_not_called() + handler.observer_callback.assert_not_called() + + def test_data_ready(self): + with mock.patch.object(self.listener, "handle_nl_msg") as m_handle: + with mock.patch.object(self.listener.ipr, "get", + return_value=["msg1", "msg2", "msg3"]): + self.listener.data_ready() + + self.assertEqual([ + mock.call("msg1"), mock.call("msg2"), mock.call("msg3")], + m_handle.mock_calls) + + def test_fileno(self): + # We need to patch the class, not the instance for some reason. + with mock.patch.object(IPRoute, "fileno", return_value=42): + self.assertEqual(42, self.listener.fileno()) + + def test_set_link_flags(self): + with mock.patch.object(self.listener.ipr, "link") as m_link: + self.listener.set_link_flags(ifindex=13, flags=IFF_UP) + + m_link.assert_called_once_with("set", index=13, flags=IFF_UP, + mask=IFF_UP) + + def test_unset_link_flags(self): + with mock.patch.object(self.listener.ipr, "link") as m_link: + self.listener.unset_link_flags(ifindex=13, flags=IFF_UP) + + m_link.assert_called_once_with("set", index=13, flags=0x0, mask=IFF_UP) diff --git a/probert/tests/rtnetlink/test_route.py b/probert/tests/rtnetlink/test_route.py new file mode 100644 index 0000000..fc43053 --- /dev/null +++ b/probert/tests/rtnetlink/test_route.py @@ -0,0 +1,159 @@ +# Copyright 2025 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import socket +import unittest + +from pyroute2.iproute.ipmock import MockRoute +from pyroute2.netlink.rtnl import rtypes + +from probert.rtnetlink.route import RouteCache, build_event_data, get_ifindex +from probert.tests.rtnetlink.common import WithGetAttrMixin + + +class MyMockRoute(WithGetAttrMixin, MockRoute): + """Subclass of pyroute2's MockRoute that makes it behave like a nlmsg... + and works around a bug""" + def __init__(self, *args, **kwargs): + # Workaround ambiguity with route type. + # See https://github.com/svinota/pyroute2/pull/1409 + if "type" in kwargs: + raise ValueError("please use route_type instead of type") + if "route_type" in kwargs: + kwargs["type"] = kwargs["route_type"] + else: + kwargs["type"] = rtypes["RTN_UNICAST"] + super().__init__(*args, **kwargs) + + +class TestGetIfindex(unittest.TestCase): + def test_no_multipath(self): + self.assertEqual(554, get_ifindex(MyMockRoute(oif=554))) + + # TODO Ideally we want a test with multipath involved but + # MockRoute does not support it so... + + +class TestRouteBuildEventData(unittest.TestCase): + def test_route4(self): + msg = MyMockRoute(dst="192.168.1.0", dst_len=24, + family=socket.AF_INET.value, oif=4) + + self.assertEqual({ + "dst": b"192.168.1.0/24", + "family": socket.AF_INET.value, + "ifindex": 4, + "type": rtypes["RTN_UNICAST"], + "table": 254, + }, build_event_data(msg)) + + def test_default4(self): + msg = MyMockRoute(dst="0.0.0.0", dst_len=0, + family=socket.AF_INET.value, oif=5) + + self.assertEqual({ + "dst": b"default", + "family": socket.AF_INET.value, + "ifindex": 5, + "type": rtypes["RTN_UNICAST"], + "table": 254, + }, build_event_data(msg)) + + def test_default6(self): + msg = MyMockRoute(dst="::", dst_len=0, + family=socket.AF_INET6.value, oif=5) + + self.assertEqual({ + "dst": b"default", + "family": socket.AF_INET6.value, + "ifindex": 5, + "type": rtypes["RTN_UNICAST"], + "table": 254, + }, build_event_data(msg)) + + def test_multicast6(self): + msg = MyMockRoute(dst="ff00::", dst_len=8, + family=socket.AF_INET6.value, oif=6, + route_type=rtypes["RTN_MULTICAST"], table=255) + + self.assertEqual({ + "dst": b"ff00::/8", + "family": socket.AF_INET6.value, + "ifindex": 6, + "type": rtypes["RTN_MULTICAST"], + "table": 255, + }, build_event_data(msg)) + + +class TestRouteCache(unittest.TestCase): + def test_unique_identifier_from_nl_msg__with_tos(self): + self.assertEqual( + RouteCache.UniqueIdentifier( + family=socket.AF_INET.value, tos=44, + table=253, dst="1.1.1.0", prio=None), + RouteCache.UniqueIdentifier.from_nl_msg(MyMockRoute( + dst="1.1.1.0", dst_len=8, family=socket.AF_INET.value, + table=253, tos=44)) + ) + + def test_unique_identifier_from_nl_msg__with_priority(self): + self.assertEqual( + RouteCache.UniqueIdentifier( + family=socket.AF_INET6.value, tos=0, + table=254, dst="aaaa::", prio=30), + RouteCache.UniqueIdentifier.from_nl_msg(MyMockRoute( + dst="aaaa::", dst_len=72, family=socket.AF_INET6.value, + priority=30)) + ) + + def test_are_entries_equal__equal(self): + # Identical routes are considered equal + self.assertTrue(RouteCache.are_entries_equal( + MyMockRoute(dst="::", dst_len=0, + family=socket.AF_INET6.value, oif=5), + MyMockRoute(dst="::", dst_len=0, + family=socket.AF_INET6.value, oif=5), + )) + self.assertTrue(RouteCache.are_entries_equal( + MyMockRoute(dst="192.168.14.0", dst_len=24, + family=socket.AF_INET.value, oif=5, table=253), + MyMockRoute(dst="192.168.14.0", dst_len=24, + family=socket.AF_INET.value, oif=5, table=253), + )) + + # This is arguably a bug in libnl that we replicated but routes with + # different destlen are considered equal. + self.assertTrue(RouteCache.are_entries_equal( + MyMockRoute(dst="192.168.14.0", dst_len=32, + family=socket.AF_INET.value), + MyMockRoute(dst="192.168.14.0", dst_len=24, + family=socket.AF_INET.value), + )) + + def test_are_entries_equal__differ(self): + # destinations differ (these will be two separate cache entries though) + self.assertFalse(RouteCache.are_entries_equal( + MyMockRoute(index=3, family=socket.AF_INET.value, prefixlen=24, + dst="10.8.0.0"), + MyMockRoute(index=3, family=socket.AF_INET.value, prefixlen=24, + dst="10.8.1.0"), + )) + + # priorities differ (these will be two separate cache entries though) + self.assertFalse(RouteCache.are_entries_equal( + MyMockRoute(index=3, family=socket.AF_INET.value, prefixlen=24, + dst="10.8.0.0", priority=40), + MyMockRoute(index=3, family=socket.AF_INET.value, prefixlen=24, + dst="10.8.0.0", priority=10), + )) diff --git a/probert/tests/test_network.py b/probert/tests/test_network.py index 64762da..3e0032e 100644 --- a/probert/tests/test_network.py +++ b/probert/tests/test_network.py @@ -19,24 +19,9 @@ class TestUdevObserver(unittest.TestCase): - def test_init_no_nl80211(self): + def test_init(self): def fake_import(name, globals=None, locals=None, fromlist=(), level=0): - if name == "probert" and "_nl80211" in fromlist: - raise ImportError - return orig_import(name) - - orig_import = __import__ - - with patch("builtins.__import__", side_effect=fake_import): - with self.assertRaises(ImportError): - UdevObserver(with_wlan_listener=True) - observer = UdevObserver(with_wlan_listener=False) - - self.assertIsNone(observer.wlan_listener) - - def test_init_with_nl80211(self): - def fake_import(name, globals=None, locals=None, fromlist=(), level=0): - if name == "probert" and "_nl80211" in fromlist: + if name == "probert" and "nl80211" in fromlist: return Mock() return orig_import(name) diff --git a/probert/tests/test_nl80211.py b/probert/tests/test_nl80211.py new file mode 100644 index 0000000..c5a87a0 --- /dev/null +++ b/probert/tests/test_nl80211.py @@ -0,0 +1,30 @@ +# Copyright 2025 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest + +from probert import nl80211 + +import pyroute2 + + +class TestNlExceptToRuntimeError(unittest.TestCase): + def test_decorated_function(self): + @nl80211.nl_except_to_runtime_err("scanning wifi failed") + def f(): + # NetlinkDumpInterrupted uses code -1 + raise pyroute2.netlink.exceptions.NetlinkDumpInterrupted() + + with self.assertRaises(RuntimeError, msg="scanning wifi failed -1"): + f() diff --git a/setup.py b/setup.py index 0f5bae6..7672051 100644 --- a/setup.py +++ b/setup.py @@ -35,12 +35,6 @@ os.system('rm -rf probert.egg-info build dist') sys.exit() -def pkgconfig(package): - return { - 'extra_compile_args': subprocess.check_output(['pkg-config', '--cflags', package]).decode('utf8').split(), - 'extra_link_args': subprocess.check_output(['pkg-config', '--libs', package]).decode('utf8').split(), - } - def read_requirement(): return [req.strip() for req in open('requirements.txt')] @@ -55,17 +49,7 @@ def read_requirement(): url='https://github.com/canonical/probert', license="AGPLv3+", scripts=['bin/probert'], - ext_modules=[ - Extension( - "probert._rtnetlink", - ['probert/_rtnetlinkmodule.c'], - **pkgconfig("libnl-route-3.0")), - Extension( - "probert._nl80211", - ['probert/_nl80211module.c'], - **pkgconfig("libnl-genl-3.0")), - ], - packages=find_packages(), + packages=find_packages(exclude=["probert.tests*"]), install_requires=read_requirement(), - include_package_data=True, + include_package_data=False, ) diff --git a/test-requirements.txt b/test-requirements.txt index 5367d85..18e9ea2 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,4 +1,5 @@ flake8 parameterized +pyroute2 pytest pytest-cov