diff --git a/msgq/impl_msgq.cc b/msgq/impl_msgq.cc index b23991351..ff12efeed 100644 --- a/msgq/impl_msgq.cc +++ b/msgq/impl_msgq.cc @@ -54,7 +54,7 @@ int MSGQSubSocket::connect(Context *context, std::string endpoint, std::string a assert(address == "127.0.0.1"); q = new msgq_queue_t; - int r = msgq_new_queue(q, endpoint.c_str(), DEFAULT_SEGMENT_SIZE); + int r = msgq_new_queue_sub(q, endpoint.c_str(), DEFAULT_SEGMENT_SIZE); if (r != 0){ return r; } @@ -147,7 +147,7 @@ int MSGQPubSocket::connect(Context *context, std::string endpoint, bool check_en //} q = new msgq_queue_t; - int r = msgq_new_queue(q, endpoint.c_str(), DEFAULT_SEGMENT_SIZE); + int r = msgq_new_queue_pub(q, endpoint.c_str(), DEFAULT_SEGMENT_SIZE); if (r != 0){ return r; } diff --git a/msgq/msgq.cc b/msgq/msgq.cc index 5ce25a3bc..800c90776 100644 --- a/msgq/msgq.cc +++ b/msgq/msgq.cc @@ -83,7 +83,18 @@ void msgq_wait_for_subscriber(msgq_queue_t *q){ return; } -int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){ +int msgq_new_queue_pub(msgq_queue_t * q, const char * path, size_t size){ + return msgq_new_queue(q, path, size, true); +} + +int msgq_new_queue_sub(msgq_queue_t * q, const char * path, size_t size){ + return msgq_new_queue(q, path, size, false); +} + +int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size, bool pub){ + size_t header_size = getpagesize(); + + assert(header_size >= sizeof(msgq_header_t)); assert(size < 0xFFFFFFFF); // Buffer must be smaller than 2^32 bytes std::signal(SIGUSR2, sigusr2_handler); @@ -100,20 +111,34 @@ int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){ return -1; } - int rc = ftruncate(fd, size + sizeof(msgq_header_t)); + int rc = ftruncate(fd, size + header_size); if (rc < 0){ close(fd); return -1; } - char * mem = (char*)mmap(NULL, size + sizeof(msgq_header_t), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); - close(fd); - if (mem == MAP_FAILED){ + char *mem_header = (char*)mmap(NULL, header_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (mem_header == MAP_FAILED){ + close(fd); + return -1; + } + + int prot = PROT_READ; + if (pub) { + prot |= PROT_WRITE; + } + + char *mem_data = (char*)mmap(NULL, size, prot, MAP_SHARED, fd, header_size); + if (mem_data == MAP_FAILED){ + munmap(mem_header, header_size); + close(fd); return -1; } - q->mmap_p = mem; - msgq_header_t *header = (msgq_header_t *)mem; + close(fd); + q->mmap_p = mem_header; + + msgq_header_t *header = (msgq_header_t *)mem_header; // Setup pointers to header segment q->num_readers = reinterpret_cast*>(&header->num_readers); @@ -126,7 +151,7 @@ int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){ q->read_uids[i] = reinterpret_cast*>(&header->read_uids[i]); } - q->data = mem + sizeof(msgq_header_t); + q->data = mem_data; q->size = size; q->reader_id = -1; @@ -138,7 +163,11 @@ int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){ void msgq_close_queue(msgq_queue_t *q){ if (q->mmap_p != NULL){ - munmap(q->mmap_p, q->size + sizeof(msgq_header_t)); + size_t header_size = getpagesize(); + munmap(q->mmap_p, header_size); + } + if (q->data != NULL){ + munmap(q->data, q->size); } } diff --git a/msgq/msgq.h b/msgq/msgq.h index 94e184944..a867ac3b5 100644 --- a/msgq/msgq.h +++ b/msgq/msgq.h @@ -57,7 +57,9 @@ int msgq_msg_init_size(msgq_msg_t *msg, size_t size); int msgq_msg_init_data(msgq_msg_t *msg, char * data, size_t size); int msgq_msg_close(msgq_msg_t *msg); -int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size); +int msgq_new_queue_pub(msgq_queue_t * q, const char * path, size_t size); +int msgq_new_queue_sub(msgq_queue_t * q, const char * path, size_t size); +int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size, bool pub); void msgq_close_queue(msgq_queue_t *q); void msgq_init_publisher(msgq_queue_t * q); void msgq_init_subscriber(msgq_queue_t * q); diff --git a/msgq/msgq_tests.cc b/msgq/msgq_tests.cc index 02f17917a..3d69b65d3 100644 --- a/msgq/msgq_tests.cc +++ b/msgq/msgq_tests.cc @@ -1,3 +1,7 @@ +#include +#include +#include + #include "catch2/catch.hpp" #include "msgq/msgq.h" @@ -45,7 +49,7 @@ TEST_CASE("msgq_init_subscriber") { remove("/dev/shm/test_queue"); msgq_queue_t q; - msgq_new_queue(&q, "test_queue", 1024); + msgq_new_queue_sub(&q, "test_queue", 1024); REQUIRE(*q.num_readers == 0); q.reader_id = 1; @@ -65,7 +69,7 @@ TEST_CASE("msgq_msg_send first message") { remove("/dev/shm/test_queue"); msgq_queue_t q; - msgq_new_queue(&q, "test_queue", 1024); + msgq_new_queue_pub(&q, "test_queue", 1024); msgq_init_publisher(&q); REQUIRE(*q.write_pointer == 0); @@ -102,7 +106,7 @@ TEST_CASE("msgq_msg_send test wraparound") { remove("/dev/shm/test_queue"); msgq_queue_t q; - msgq_new_queue(&q, "test_queue", 1024); + msgq_new_queue_pub(&q, "test_queue", 1024); msgq_init_publisher(&q); REQUIRE((*q.write_pointer & 0xFFFFFFFF) == 0); @@ -134,8 +138,8 @@ TEST_CASE("msgq_msg_recv test wraparound") { remove("/dev/shm/test_queue"); msgq_queue_t q_pub, q_sub; - msgq_new_queue(&q_pub, "test_queue", 1024); - msgq_new_queue(&q_sub, "test_queue", 1024); + msgq_new_queue_pub(&q_pub, "test_queue", 1024); + msgq_new_queue_sub(&q_sub, "test_queue", 1024); msgq_init_publisher(&q_pub); msgq_init_subscriber(&q_sub); @@ -180,8 +184,8 @@ TEST_CASE("msgq_msg_send test invalidation") { remove("/dev/shm/test_queue"); msgq_queue_t q_pub, q_sub; - msgq_new_queue(&q_pub, "test_queue", 1024); - msgq_new_queue(&q_sub, "test_queue", 1024); + msgq_new_queue_pub(&q_pub, "test_queue", 1024); + msgq_new_queue_sub(&q_sub, "test_queue", 1024); msgq_init_publisher(&q_pub); msgq_init_subscriber(&q_sub); @@ -216,8 +220,8 @@ TEST_CASE("msgq_init_subscriber init 2 subscribers") { remove("/dev/shm/test_queue"); msgq_queue_t q1, q2; - msgq_new_queue(&q1, "test_queue", 1024); - msgq_new_queue(&q2, "test_queue", 1024); + msgq_new_queue_sub(&q1, "test_queue", 1024); + msgq_new_queue_sub(&q2, "test_queue", 1024); *q1.num_readers = 0; @@ -241,8 +245,44 @@ TEST_CASE("Write 1 msg, read 1 msg", "[integration]") const size_t msg_size = 128; msgq_queue_t writer, reader; - msgq_new_queue(&writer, "test_queue", 1024); - msgq_new_queue(&reader, "test_queue", 1024); + msgq_new_queue_pub(&writer, "test_queue", 1024); + msgq_new_queue_sub(&reader, "test_queue", 1024); + + msgq_init_publisher(&writer); + msgq_init_subscriber(&reader); + + // Build 128 byte message + msgq_msg_t outgoing_msg; + msgq_msg_init_size(&outgoing_msg, msg_size); + + for (size_t i = 0; i < msg_size; i++) + { + outgoing_msg.data[i] = i; + } + + REQUIRE(msgq_msg_send(&outgoing_msg, &writer) == msg_size); + + msgq_msg_t incoming_msg1; + REQUIRE(msgq_msg_recv(&incoming_msg1, &reader) == msg_size); + REQUIRE(memcmp(incoming_msg1.data, outgoing_msg.data, msg_size) == 0); + + // Verify that there are no more messages + msgq_msg_t incoming_msg2; + REQUIRE(msgq_msg_recv(&incoming_msg2, &reader) == 0); + + msgq_msg_close(&outgoing_msg); + msgq_msg_close(&incoming_msg1); + msgq_msg_close(&incoming_msg2); +} + +TEST_CASE("Write/read 1 msg, detect violate permission", "[integration]") +{ + remove("/dev/shm/test_queue"); + const size_t msg_size = 128; + msgq_queue_t writer, reader; + + msgq_new_queue_pub(&writer, "test_queue", 1024); + msgq_new_queue_sub(&reader, "test_queue", 1024); msgq_init_publisher(&writer); msgq_init_subscriber(&reader); @@ -266,6 +306,30 @@ TEST_CASE("Write 1 msg, read 1 msg", "[integration]") msgq_msg_t incoming_msg2; REQUIRE(msgq_msg_recv(&incoming_msg2, &reader) == 0); + // Wait SIGSEGV to detect write access from subscriber + pid_t pid = fork(); + if (pid != 0) { + // Parent: Wait SIGSEGV of the child + int status; + pid_t res = waitpid(pid, &status, 0); + REQUIRE(res == pid); + REQUIRE(WIFSIGNALED(status)); + REQUIRE(WTERMSIG(status) == SIGSEGV); + } else { + // Child: Remove CATCH2's signal handler and write + struct sigaction act; + act.sa_handler = SIG_DFL; + sigaction(SIGSEGV, &act, NULL); + // Try to write into read-only area + incoming_msg2.data[0] = 1; + exit(0); + } + + for (size_t i = 0; i < msg_size; i++) + { + REQUIRE(outgoing_msg.data[i] == i); + } + msgq_msg_close(&outgoing_msg); msgq_msg_close(&incoming_msg1); msgq_msg_close(&incoming_msg2); @@ -277,8 +341,8 @@ TEST_CASE("Write 2 msg, read 2 msg - conflate = false", "[integration]") const size_t msg_size = 128; msgq_queue_t writer, reader; - msgq_new_queue(&writer, "test_queue", 1024); - msgq_new_queue(&reader, "test_queue", 1024); + msgq_new_queue_pub(&writer, "test_queue", 1024); + msgq_new_queue_sub(&reader, "test_queue", 1024); msgq_init_publisher(&writer); msgq_init_subscriber(&reader); @@ -314,8 +378,8 @@ TEST_CASE("Write 2 msg, read 2 msg - conflate = true", "[integration]") const size_t msg_size = 128; msgq_queue_t writer, reader; - msgq_new_queue(&writer, "test_queue", 1024); - msgq_new_queue(&reader, "test_queue", 1024); + msgq_new_queue_pub(&writer, "test_queue", 1024); + msgq_new_queue_sub(&reader, "test_queue", 1024); msgq_init_publisher(&writer); msgq_init_subscriber(&reader); @@ -351,8 +415,8 @@ TEST_CASE("1 publisher, 1 slow subscriber", "[integration]") remove("/dev/shm/test_queue"); msgq_queue_t writer, reader; - msgq_new_queue(&writer, "test_queue", 1024); - msgq_new_queue(&reader, "test_queue", 1024); + msgq_new_queue_pub(&writer, "test_queue", 1024); + msgq_new_queue_sub(&reader, "test_queue", 1024); msgq_init_publisher(&writer); msgq_init_subscriber(&reader); @@ -394,9 +458,9 @@ TEST_CASE("1 publisher, 2 subscribers", "[integration]") remove("/dev/shm/test_queue"); msgq_queue_t writer, reader1, reader2; - msgq_new_queue(&writer, "test_queue", 1024); - msgq_new_queue(&reader1, "test_queue", 1024); - msgq_new_queue(&reader2, "test_queue", 1024); + msgq_new_queue_pub(&writer, "test_queue", 1024); + msgq_new_queue_sub(&reader1, "test_queue", 1024); + msgq_new_queue_sub(&reader2, "test_queue", 1024); msgq_init_publisher(&writer); msgq_init_subscriber(&reader1); diff --git a/msgq/visionipc/visionbuf_cl.cc b/msgq/visionipc/visionbuf_cl.cc index db1ca0334..e8a7481e1 100644 --- a/msgq/visionipc/visionbuf_cl.cc +++ b/msgq/visionipc/visionbuf_cl.cc @@ -52,7 +52,7 @@ void VisionBuf::init_cl(cl_device_id device_id, cl_context ctx){ void VisionBuf::import(){ assert(this->fd >= 0); - this->addr = mmap(NULL, this->mmap_len, PROT_READ | PROT_WRITE, MAP_SHARED, this->fd, 0); + this->addr = mmap(NULL, this->mmap_len, PROT_READ, MAP_SHARED, this->fd, 0); assert(this->addr != MAP_FAILED); this->frame_id = (uint64_t*)((uint8_t*)this->addr + this->len);