LCOV - code coverage report
Current view: top level - ipc/server/source - IpcServerImpl.cpp (source / functions) Coverage Total Hit
Test: coverage.info Lines: 62.7 % 565 354
Test Date: 2026-06-17 06:35:35 Functions: 86.2 % 29 25

            Line data    Source code
       1              : /*
       2              :  * If not stated otherwise in this file or this component's LICENSE file the
       3              :  * following copyright and licenses apply:
       4              :  *
       5              :  * Copyright 2022 Sky UK
       6              :  *
       7              :  * Licensed under the Apache License, Version 2.0 (the "License");
       8              :  * you may not use this file except in compliance with the License.
       9              :  * You may obtain a copy of the License at
      10              :  *
      11              :  * http://www.apache.org/licenses/LICENSE-2.0
      12              :  *
      13              :  * Unless required by applicable law or agreed to in writing, software
      14              :  * distributed under the License is distributed on an "AS IS" BASIS,
      15              :  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      16              :  * See the License for the specific language governing permissions and
      17              :  * limitations under the License.
      18              :  */
      19              : 
      20              : #include "IpcServerImpl.h"
      21              : #include "IIpcServerFactory.h"
      22              : #include "IpcClientImpl.h"
      23              : #include "IpcLogging.h"
      24              : #include "IpcServerControllerImpl.h"
      25              : 
      26              : #include "rialtoipc.pb.h"
      27              : 
      28              : #include <google/protobuf/service.h>
      29              : 
      30              : #include <algorithm>
      31              : #include <cinttypes>
      32              : #include <cstdarg>
      33              : #include <string>
      34              : #include <utility>
      35              : #include <vector>
      36              : 
      37              : #include <fcntl.h>
      38              : #include <poll.h>
      39              : #include <sys/epoll.h>
      40              : #include <sys/eventfd.h>
      41              : #include <sys/file.h>
      42              : #include <sys/socket.h>
      43              : #include <sys/stat.h>
      44              : #include <sys/un.h>
      45              : #include <unistd.h>
      46              : 
      // epoll user-data ids: 0 is reserved for the wake eventfd, listening
// sockets are numbered from 1 up to (but excluding) 10000, and client
// connections are numbered from 10000 upwards.
#define WAKE_EVENT_ID (static_cast<uint64_t>(0))
#define FIRST_LISTENING_SOCKET_ID (static_cast<uint64_t>(1))
#define FIRST_CLIENT_ID (static_cast<uint64_t>(10000))
      50              : 
      51              : namespace firebolt::rialto::ipc
      52              : {
      53              : const size_t ServerImpl::m_kMaxMessageLen = (128 * 1024);
      54              : 
      55           32 : std::shared_ptr<IServerFactory> IServerFactory::createFactory()
      56              : {
      57           32 :     std::shared_ptr<IServerFactory> factory;
      58              :     try
      59              :     {
      60           32 :         factory = std::make_shared<ServerFactory>();
      61              :     }
      62            0 :     catch (const std::exception &e)
      63              :     {
      64            0 :         RIALTO_IPC_LOG_ERROR("Failed to create the server factory, reason: %s", e.what());
      65              :     }
      66              : 
      67           32 :     return factory;
      68              : }
      69              : 
      70           32 : std::shared_ptr<IServer> ServerFactory::create()
      71              : {
      72           32 :     return std::make_shared<ServerImpl>();
      73              : }
      74              : 
      75           32 : ServerImpl::ServerImpl()
      76           32 :     : m_pollFd(-1), m_wakeEventFd(-1), m_socketIdCounter(FIRST_LISTENING_SOCKET_ID), m_clientIdCounter(FIRST_CLIENT_ID),
      77           64 :       m_recvDataBuf{0}, m_recvCtrlBuf{0}
      78              : {
      79              :     // create the eventfd use to wake the poll loop
      80           32 :     m_wakeEventFd = eventfd(0, EFD_CLOEXEC);
      81           32 :     if (m_wakeEventFd < 0)
      82              :     {
      83            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "eventfd failed");
      84            0 :         return;
      85              :     }
      86              : 
      87              :     // create epoll loop
      88           32 :     m_pollFd = epoll_create1(EPOLL_CLOEXEC);
      89           32 :     if (m_pollFd < 0)
      90              :     {
      91            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_create1 failed");
      92            0 :         return;
      93              :     }
      94              : 
      95              :     // add the wake event to epoll
      96           32 :     epoll_event event = {.events = EPOLLIN, .data = {.u64 = WAKE_EVENT_ID}};
      97           32 :     if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, m_wakeEventFd, &event) != 0)
      98              :     {
      99            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
     100              :     }
     101              : }
     102              : 
     103           32 : ServerImpl::~ServerImpl()
     104              : {
     105           32 :     if ((m_pollFd >= 0) && (close(m_pollFd) != 0))
     106            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close epoll");
     107              : 
     108           32 :     if ((m_wakeEventFd >= 0) && (close(m_wakeEventFd) != 0))
     109            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close eventfd");
     110              : 
     111           45 :     for (const auto &entry : m_sockets)
     112              :     {
     113           13 :         const Socket &kSocket = entry.second;
     114              : 
     115           13 :         if (kSocket.isOwned)
     116              :         {
     117            9 :             if (unlink(kSocket.sockPath.c_str()) != 0)
     118            0 :                 RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket @ '%s'", kSocket.sockPath.c_str());
     119            9 :             if (close(kSocket.sockFd) != 0)
     120            0 :                 RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close listening socket");
     121              : 
     122            9 :             if (unlink(kSocket.lockPath.c_str()) != 0)
     123            0 :                 RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket lock file @ '%s'", kSocket.lockPath.c_str());
     124            9 :             if (close(kSocket.lockFd) != 0)
     125            0 :                 RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket lock file");
     126              :         }
     127              :     }
     128           32 : }
     129              : 
     130              : // -----------------------------------------------------------------------------
     131              : /*!
     132              :     \static
     133              :     \internal
     134              : 
     135              :     Creates (if required) and takes the file lock associated with the socket
     136              :     path in the \a socket object.
     137              : 
     138              :  */
     139            9 : bool ServerImpl::getSocketLock(Socket *socket)
     140              : {
     141            9 :     std::string lockPath = socket->sockPath + ".lock";
     142            9 :     int fileDescriptor = open(lockPath.c_str(), O_CREAT | O_CLOEXEC, (S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP));
     143            9 :     if (fileDescriptor < 0)
     144              :     {
     145            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to create / open lockfile @ '%s' (check permissions)", lockPath.c_str());
     146            0 :         return false;
     147              :     }
     148              : 
     149            9 :     if (flock(fileDescriptor, LOCK_EX | LOCK_NB) < 0)
     150              :     {
     151            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to lock lockfile @ '%s', maybe another server is running",
     152              :                                  lockPath.c_str());
     153            0 :         close(fileDescriptor);
     154            0 :         return false;
     155              :     }
     156              : 
     157            9 :     struct stat sbuf = {0};
     158            9 :     if (stat(socket->sockPath.c_str(), &sbuf) < 0)
     159              :     {
     160            9 :         if (errno != ENOENT)
     161              :         {
     162            0 :             RIALTO_IPC_LOG_SYS_ERROR(errno, "did not manage to stat existing socket @ '%s'", socket->sockPath.c_str());
     163            0 :             close(fileDescriptor);
     164            0 :             return false;
     165              :         }
     166              :     }
     167            0 :     else if ((sbuf.st_mode & S_IWUSR) || (sbuf.st_mode & S_IWGRP))
     168              :     {
     169            0 :         unlink(socket->sockPath.c_str());
     170              :     }
     171              : 
     172            9 :     socket->lockFd = fileDescriptor;
     173            9 :     socket->lockPath = std::move(lockPath);
     174              : 
     175            9 :     return true;
     176              : }
     177              : 
     178              : // -----------------------------------------------------------------------------
     179              : /*!
     180              :     \static
     181              :     \internal
     182              : 
     183              :     Closes the file descriptors and the deletes the files stored in the \a socket
     184              :     object.
     185              : 
     186              :  */
     187            0 : void ServerImpl::closeListeningSocket(Socket *socket)
     188              : {
     189            0 :     if (!socket->isOwned)
     190              :     {
     191            0 :         return;
     192              :     }
     193              : 
     194            0 :     if (!socket->sockPath.empty() && (unlink(socket->sockPath.c_str()) != 0) && (errno != ENOENT))
     195            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket @ '%s'", socket->sockPath.c_str());
     196            0 :     if ((socket->sockFd >= 0) && (close(socket->sockFd) != 0))
     197            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close listening socket");
     198              : 
     199            0 :     if (!socket->lockPath.empty() && (unlink(socket->lockPath.c_str()) != 0) && (errno != ENOENT))
     200            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket lock file @ '%s'", socket->lockPath.c_str());
     201            0 :     if ((socket->lockFd >= 0) && (close(socket->lockFd) != 0))
     202            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket lock file");
     203              : 
     204            0 :     socket->sockFd = -1;
     205            0 :     socket->sockPath.clear();
     206              : 
     207            0 :     socket->lockFd = -1;
     208            0 :     socket->lockPath.clear();
     209              : }
     210              : 
     211            9 : bool ServerImpl::addSocket(const std::string &socketPath,
     212              :                            const std::function<void(const std::shared_ptr<IClient> &)> &clientConnectedCb,
     213              :                            const std::function<void(const std::shared_ptr<IClient> &)> &clientDisconnectedCb)
     214              : {
     215              :     // store the path
     216            9 :     Socket socket;
     217            9 :     socket.sockPath = socketPath;
     218              : 
     219              :     // create the socket
     220            9 :     socket.sockFd = ::socket(AF_UNIX, SOCK_SEQPACKET | SOCK_CLOEXEC | SOCK_NONBLOCK, 0);
     221            9 :     if (socket.sockFd == -1)
     222              :     {
     223            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "socket error");
     224            0 :         return false;
     225              :     }
     226              : 
     227              :     // get the socket lock
     228            9 :     if (!getSocketLock(&socket))
     229              :     {
     230            0 :         closeListeningSocket(&socket);
     231            0 :         return false;
     232              :     }
     233              : 
     234              :     // bind to the given path
     235            9 :     struct sockaddr_un addr = {0};
     236            9 :     memset(&addr, 0x00, sizeof(addr));
     237            9 :     addr.sun_family = AF_UNIX;
     238            9 :     strncpy(addr.sun_path, socketPath.c_str(), sizeof(addr.sun_path) - 1);
     239              : 
     240            9 :     if (bind(socket.sockFd, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr)) == -1)
     241              :     {
     242            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "bind error");
     243              : 
     244            0 :         closeListeningSocket(&socket);
     245            0 :         return false;
     246              :     }
     247              : 
     248              :     // put in listening mode
     249            9 :     if (listen(socket.sockFd, 1) == -1)
     250              :     {
     251            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "listen error");
     252              : 
     253            0 :         closeListeningSocket(&socket);
     254            0 :         return false;
     255              :     }
     256              : 
     257              :     // create an id for the listening socket
     258            9 :     const uint64_t kSocketId = m_socketIdCounter++;
     259            9 :     if (kSocketId >= FIRST_CLIENT_ID)
     260              :     {
     261              :         // should never happen, we'd run out of file descriptors before
     262              :         // we hit the 10k limit on listening sockets
     263            0 :         RIALTO_IPC_LOG_ERROR("too many listening sockets");
     264              : 
     265            0 :         closeListeningSocket(&socket);
     266            0 :         return false;
     267              :     }
     268              : 
     269              :     // add the socket to epoll
     270            9 :     epoll_event event = {.events = EPOLLIN, .data = {.u64 = kSocketId}};
     271            9 :     if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socket.sockFd, &event) != 0)
     272              :     {
     273            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add listening socket");
     274              : 
     275            0 :         closeListeningSocket(&socket);
     276            0 :         return false;
     277              :     }
     278              : 
     279              :     // store the client connected / disconnected callbacks
     280            9 :     socket.connectedCb = clientConnectedCb;
     281            9 :     socket.disconnectedCb = clientDisconnectedCb;
     282              : 
     283              :     // add to the internal map
     284            9 :     std::lock_guard<std::mutex> locker(m_socketsLock);
     285            9 :     m_sockets.emplace(kSocketId, std::move(socket));
     286              : 
     287            9 :     RIALTO_IPC_LOG_INFO("added listening socket '%s' to server", socketPath.c_str());
     288              : 
     289            9 :     return true;
     290              : }
     291              : 
     292            4 : bool ServerImpl::addSocket(int fd, const std::function<void(const std::shared_ptr<IClient> &)> &clientConnectedCb,
     293              :                            const std::function<void(const std::shared_ptr<IClient> &)> &clientDisconnectedCb)
     294              : {
     295              :     // store the path
     296            4 :     Socket socket;
     297            4 :     socket.isOwned = false;
     298            4 :     socket.sockFd = fd;
     299              : 
     300              :     // put in listening mode
     301            4 :     if (listen(socket.sockFd, 1) == -1)
     302              :     {
     303            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "listen error");
     304            0 :         return false;
     305              :     }
     306              : 
     307              :     // create an id for the listening socket
     308            4 :     const uint64_t kSocketId = m_socketIdCounter++;
     309            4 :     if (kSocketId >= FIRST_CLIENT_ID)
     310              :     {
     311              :         // should never happen, we'd run out of file descriptors before
     312              :         // we hit the 10k limit on listening sockets
     313            0 :         RIALTO_IPC_LOG_ERROR("too many listening sockets");
     314            0 :         return false;
     315              :     }
     316              : 
     317              :     // add the socket to epoll
     318            4 :     epoll_event event = {.events = EPOLLIN, .data = {.u64 = kSocketId}};
     319            4 :     if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socket.sockFd, &event) != 0)
     320              :     {
     321            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add listening socket");
     322            0 :         return false;
     323              :     }
     324              : 
     325              :     // store the client connected / disconnected callbacks
     326            4 :     socket.connectedCb = clientConnectedCb;
     327            4 :     socket.disconnectedCb = clientDisconnectedCb;
     328              : 
     329              :     // add to the internal map
     330            4 :     std::lock_guard<std::mutex> locker(m_socketsLock);
     331            4 :     m_sockets.emplace(kSocketId, std::move(socket));
     332              : 
     333            4 :     RIALTO_IPC_LOG_INFO("added listening socket with fd: %d to server", fd);
     334              : 
     335            4 :     return true;
     336              : }
     337              : 
     338           16 : std::shared_ptr<IClient> ServerImpl::addClient(int socketFd,
     339              :                                                std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb)
     340              : {
     341              :     // sanity check the supplied socket is of the right type
     342              :     struct sockaddr addr;
     343           16 :     socklen_t len = sizeof(sockaddr);
     344           16 :     if ((getsockname(socketFd, &addr, &len) < 0) || (len < sizeof(sa_family_t)))
     345              :     {
     346            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get name of supplied socket");
     347            0 :         return nullptr;
     348              :     }
     349           16 :     if (addr.sa_family != AF_UNIX)
     350              :     {
     351            0 :         RIALTO_IPC_LOG_ERROR("supplied client socket is not a unix domain socket");
     352            0 :         return nullptr;
     353              :     }
     354              : 
     355           16 :     int type = 0;
     356           16 :     len = sizeof(type);
     357           16 :     if ((getsockopt(socketFd, SOL_SOCKET, SO_TYPE, &type, &len) < 0) || (len != sizeof(type)))
     358              :     {
     359            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
     360            0 :         return nullptr;
     361              :     }
     362           16 :     if (type != SOCK_SEQPACKET)
     363              :     {
     364            0 :         RIALTO_IPC_LOG_ERROR("supplied client socket is not of type SOCK_SEQPACKET");
     365            0 :         return nullptr;
     366              :     }
     367              : 
     368              :     // dup the socket and set the O_CLOEXEC bit
     369           16 :     int duppedFd = fcntl(socketFd, F_DUPFD_CLOEXEC, 3);
     370           16 :     if (duppedFd < 0)
     371              :     {
     372            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
     373            0 :         return nullptr;
     374              :     }
     375              : 
     376              :     // ensure the SOCK_NONBLOCK flag is set
     377           16 :     int flags = fcntl(duppedFd, F_GETFL);
     378           16 :     if ((flags < 0) || (fcntl(duppedFd, F_SETFL, flags | O_NONBLOCK) < 0))
     379              :     {
     380            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to set socket to non-blocking mode");
     381            0 :         close(duppedFd);
     382            0 :         return nullptr;
     383              :     }
     384              : 
     385              :     // finally, add the socket to the list of clients
     386           48 :     auto client = addClientSocket(duppedFd, "", std::move(clientDisconnectedCb));
     387           16 :     if (!client)
     388              :     {
     389            0 :         close(duppedFd);
     390              :     }
     391              : 
     392           16 :     return client;
     393              : }
     394              : 
     395            0 : int ServerImpl::fd() const
     396              : {
     397            0 :     return m_pollFd;
     398              : }
     399              : 
     400              : // -----------------------------------------------------------------------------
     401              : /*!
     402              :     \internal
     403              : 
     404              :     Writes to the eventfd to wake the event loop.  Typically called when requesting
     405              :     it to shutdown or external code has requested that a client be disconnected.
     406              : 
     407              :  */
     408           29 : void ServerImpl::wakeEventLoop() const
     409              : {
     410           29 :     if (m_wakeEventFd < 0)
     411              :     {
     412            0 :         RIALTO_IPC_LOG_ERROR("invalid wake event fd");
     413              :     }
     414              :     else
     415              :     {
     416           29 :         uint64_t value = 1;
     417           29 :         if (TEMP_FAILURE_RETRY(::write(m_wakeEventFd, &value, sizeof(value))) != sizeof(value))
     418              :         {
     419            0 :             RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to write to the event fd");
     420              :         }
     421              :     }
     422           29 : }
     423              : 
     424         2891 : bool ServerImpl::wait(int timeoutMSecs)
     425              : {
     426         2891 :     if (m_pollFd < 0)
     427              :     {
     428            0 :         return false;
     429              :     }
     430              : 
     431              :     // wait for any event (with timeout)
     432              :     struct pollfd fds[2];
     433         2891 :     fds[0].fd = m_pollFd;
     434         2891 :     fds[0].events = POLLIN;
     435              : 
     436         2891 :     int rc = TEMP_FAILURE_RETRY(poll(fds, 1, timeoutMSecs));
     437         2891 :     if (rc < 0)
     438              :     {
     439            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "poll failed?");
     440            0 :         return false;
     441              :     }
     442              : 
     443         2891 :     return true;
     444              : }
     445              : 
     446         2920 : bool ServerImpl::process()
     447              : {
     448         2920 :     if (m_pollFd < 0)
     449              :     {
     450            0 :         RIALTO_IPC_LOG_ERROR("missing epoll");
     451            0 :         return false;
     452              :     }
     453              : 
     454              :     // read up to 32 events
     455         2920 :     const int kMaxEvents = 32;
     456              :     struct epoll_event events[kMaxEvents];
     457              : 
     458         2920 :     int rc = TEMP_FAILURE_RETRY(epoll_wait(m_pollFd, events, kMaxEvents, 0));
     459         2920 :     if (rc < 0)
     460              :     {
     461            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_wait failed");
     462            0 :         return false;
     463              :     }
     464              : 
     465              :     // process the events (maybe 0 if timed out)
     466         2985 :     for (int i = 0; i < rc; i++)
     467              :     {
     468           65 :         const struct epoll_event &kEvent = events[i];
     469              : 
     470              :         // check if a wake event, in which case just clear the eventfd
     471           65 :         if (kEvent.data.u64 == WAKE_EVENT_ID)
     472              :         {
     473              :             uint64_t ignore;
     474            3 :             if (TEMP_FAILURE_RETRY(::read(m_wakeEventFd, &ignore, sizeof(ignore))) != sizeof(ignore))
     475              :             {
     476            0 :                 RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to read wake eventfd");
     477              :             }
     478              :         }
     479              : 
     480              :         // check for events on the listening socket
     481           62 :         else if (kEvent.data.u64 < FIRST_CLIENT_ID)
     482              :         {
     483           13 :             if (kEvent.events & EPOLLIN)
     484           13 :                 processNewConnection(kEvent.data.u64);
     485           13 :             if (kEvent.events & EPOLLERR)
     486            0 :                 RIALTO_IPC_LOG_ERROR("error occurred on listening socket");
     487              :         }
     488              : 
     489              :         // otherwise, the event must have come from a socket
     490              :         else
     491              :         {
     492           49 :             processClientSocket(kEvent.data.u64, kEvent.events);
     493              :         }
     494              :     }
     495              : 
     496              :     // if we have client sockets that are condemned then we need to shut down
     497              :     // and close them as well as remove from epoll
     498         2920 :     std::unique_lock<std::mutex> locker(m_clientsLock);
     499         2920 :     if (!m_condemnedClients.empty())
     500              :     {
     501              :         // take a copy of the set so we can process without the lock held
     502           29 :         std::set<uint64_t> theCondemned;
     503           29 :         m_condemnedClients.swap(theCondemned);
     504              : 
     505           58 :         for (uint64_t clientId : theCondemned)
     506              :         {
     507           29 :             auto it = m_clients.find(clientId);
     508           29 :             if (it == m_clients.end())
     509              :             {
     510            0 :                 RIALTO_IPC_LOG_ERROR("failed to find condemned client");
     511            0 :                 continue;
     512              :             }
     513              : 
     514           29 :             ClientDetails details = it->second;
     515              : 
     516              :             // remove from the list of clients
     517           29 :             m_clients.erase(it);
     518              : 
     519              :             // drop the lock while closing the connection and removing from epoll
     520           29 :             locker.unlock();
     521              : 
     522              :             // remove the socket from epoll and close it
     523           29 :             if (details.sock >= 0)
     524              :             {
     525           29 :                 if (epoll_ctl(m_pollFd, EPOLL_CTL_DEL, details.sock, nullptr) != 0)
     526            0 :                     RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket from epoll");
     527              : 
     528           29 :                 if (shutdown(details.sock, SHUT_RDWR) != 0)
     529            0 :                     RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to shutdown socket");
     530           29 :                 if (close(details.sock) != 0)
     531            0 :                     RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket");
     532              :             }
     533              : 
     534              :             // let the installed handler know a client has disconnected
     535           29 :             if (details.disconnectedCb)
     536           29 :                 details.disconnectedCb(details.client);
     537              : 
     538              :             // ensure client object is destructed without the client lock held
     539           29 :             details.client.reset();
     540              : 
     541              :             // re-take the lock for the next client
     542           29 :             locker.lock();
     543              :         }
     544              :     }
     545              : 
     546         2920 :     return true;
     547              : }
     548              : 
     549              : // -----------------------------------------------------------------------------
     550              : /*!
     551              :     \internal
     552              : 
     553              :     Adds the \a socketFd to the internal list of client sockets.  This is called
     554              :     when a new connection is accepted on a listening socket, or when a client
     555              :     fd is added via ServerImpl::addClient(...).
     556              : 
     557              :     Returns a nullptr if failed to add the socket.
     558              : 
     559              :  */
     560           29 : std::shared_ptr<ClientImpl> ServerImpl::addClientSocket(int socketFd, const std::string &listeningSocketPath,
     561              :                                                         std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb)
     562              : {
     563              :     // get the client credentials
     564           29 :     struct ucred clientCreds = {0};
     565           29 :     socklen_t clientCredsLen = sizeof(clientCreds);
     566              : 
     567           58 :     if ((getsockopt(socketFd, SOL_SOCKET, SO_PEERCRED, &clientCreds, &clientCredsLen) < 0) ||
     568           29 :         (clientCredsLen != sizeof(clientCreds)))
     569              :     {
     570            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get client's details");
     571            0 :         return nullptr;
     572              :     }
     573              : 
     574              :     // create a new unique client id for the connection
     575           29 :     const uint64_t kClientId = m_clientIdCounter++;
     576              : 
     577              :     // add the new socket to the poll loop
     578           29 :     epoll_event event = {.events = EPOLLIN, .data = {.u64 = kClientId}};
     579           29 :     if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socketFd, &event) != 0)
     580              :     {
     581            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add client socket");
     582            0 :         return nullptr;
     583              :     }
     584              : 
     585              :     // create initial client object for the socket
     586           29 :     auto client = std::make_shared<ClientImpl>(shared_from_this(), kClientId, clientCreds);
     587              : 
     588              :     // and the details for the internal list
     589           29 :     ClientDetails clientDetails;
     590           29 :     clientDetails.sock = socketFd;
     591           29 :     clientDetails.disconnectedCb = std::move(disconnectedCb);
     592           29 :     clientDetails.client = client;
     593              : 
     594              :     // add to the set of clients
     595              :     {
     596           29 :         std::lock_guard<std::mutex> locker(m_clientsLock);
     597           29 :         m_clients.emplace(kClientId, clientDetails);
     598              :     }
     599              : 
     600           29 :     RIALTO_IPC_LOG_INFO("new client connected - giving id %" PRIu64, kClientId);
     601              : 
     602           29 :     return client;
     603              : }
     604              : 
     605              : // -----------------------------------------------------------------------------
     606              : /*!
     607              :     \internal
     608              : 
     609              :     Called when an event occurs on the listening socket. The code first accepts
     610              :     the connection and retrieves the client details, it then calls the installed
     611              :     handler to determine if we should drop this connection or not.
     612              : 
     613              : 
     614              :  */
     615           13 : void ServerImpl::processNewConnection(uint64_t socketId)
     616              : {
     617           13 :     RIALTO_IPC_LOG_DEBUG("processing new connection");
     618              : 
     619              :     int clientSock;
     620           13 :     std::string sockPath;
     621           13 :     std::function<void(const std::shared_ptr<IClient> &)> connectedCb;
     622           13 :     std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb;
     623              : 
     624              :     {
     625           13 :         std::unique_lock<std::mutex> socketLocker(m_socketsLock);
     626              : 
     627              :         // find matching socket object
     628           13 :         auto it = m_sockets.find(socketId);
     629           13 :         if (it == m_sockets.end())
     630              :         {
     631            0 :             RIALTO_IPC_LOG_ERROR("failed to find listening socket with id %" PRIu64, socketId);
     632            0 :             return;
     633              :         }
     634              : 
     635           13 :         const Socket &kSocket = it->second;
     636              : 
     637              :         // accept the connection from the client
     638           13 :         struct sockaddr clientAddr = {0};
     639           13 :         socklen_t clientAddrLen = sizeof(clientAddr);
     640              : 
     641           13 :         clientSock = accept4(kSocket.sockFd, &clientAddr, &clientAddrLen, SOCK_NONBLOCK | SOCK_CLOEXEC);
     642           13 :         if (clientSock < 0)
     643              :         {
     644            0 :             RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to accept client connection");
     645            0 :             return;
     646              :         }
     647              : 
     648           13 :         sockPath = kSocket.sockPath;
     649           13 :         connectedCb = kSocket.connectedCb;
     650           13 :         disconnectedCb = kSocket.disconnectedCb;
     651              :     }
     652              : 
     653              :     // attempt to add the socket to the client list
     654           13 :     auto client = addClientSocket(clientSock, sockPath, std::move(disconnectedCb));
     655           13 :     if (!client)
     656              :     {
     657            0 :         close(clientSock);
     658            0 :         return;
     659              :     }
     660              : 
     661              :     // notify the handler that a new connection has been made
     662           13 :     if (connectedCb)
     663              :     {
     664              :         // tell the handler we have a new client
     665           13 :         connectedCb(client);
     666              :     }
     667              : }
     668              : 
     669              : // -----------------------------------------------------------------------------
     670              : /*!
     671              :     \internal
     672              :     \static
     673              : 
     674              :     Reads all the file descriptors from a message. It returns the file descriptors
     675              :     as a vector of FileDescriptor objects, these objects safely store the fd and
     676              :     close them when they're destructed.
     677              : 
     678              :  */
     679            0 : static std::vector<FileDescriptor> readMessageFds(const struct msghdr *msg, size_t limit)
     680              : {
     681            0 :     std::vector<FileDescriptor> fds;
     682              : 
     683            0 :     for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg); cmsg != nullptr;
     684            0 :          cmsg = CMSG_NXTHDR(const_cast<struct msghdr *>(msg), cmsg))
     685              :     {
     686            0 :         if ((cmsg->cmsg_level == SOL_SOCKET) && (cmsg->cmsg_type == SCM_RIGHTS))
     687              :         {
     688            0 :             const unsigned kFdsLength = cmsg->cmsg_len - CMSG_LEN(0);
     689            0 :             if ((kFdsLength < sizeof(int)) || ((kFdsLength % sizeof(int)) != 0))
     690              :             {
     691            0 :                 RIALTO_IPC_LOG_ERROR("invalid fd array size");
     692              :             }
     693              :             else
     694              :             {
     695            0 :                 const size_t n = kFdsLength / sizeof(int);
     696            0 :                 RIALTO_IPC_LOG_DEBUG("received %zu fds", n);
     697              : 
     698            0 :                 fds.reserve(std::min(limit, n));
     699              : 
     700            0 :                 const int *kFds = reinterpret_cast<int *>(CMSG_DATA(cmsg));
     701            0 :                 for (size_t i = 0; i < n; i++)
     702              :                 {
     703            0 :                     RIALTO_IPC_LOG_DEBUG("received fd %d", kFds[i]);
     704              : 
     705            0 :                     if (fds.size() >= limit)
     706              :                     {
     707            0 :                         RIALTO_IPC_LOG_ERROR("received to many file descriptors, "
     708              :                                              "exceeding max per message, closing left overs");
     709              :                     }
     710              :                     else
     711              :                     {
     712            0 :                         FileDescriptor fd(kFds[i]);
     713            0 :                         if (!fd.isValid())
     714              :                         {
     715            0 :                             RIALTO_IPC_LOG_ERROR("received invalid fd (couldn't dup)");
     716              :                         }
     717              :                         else
     718              :                         {
     719            0 :                             fds.emplace_back(std::move(fd));
     720              :                         }
     721              :                     }
     722              : 
     723            0 :                     if (close(kFds[i]) != 0)
     724              :                     {
     725            0 :                         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close received fd");
     726              :                     }
     727              :                 }
     728              :             }
     729              :         }
     730              :     }
     731              : 
     732            0 :     return fds;
     733              : }
     734              : 
     735              : // -----------------------------------------------------------------------------
     736              : /*!
     737              :     \internal
     738              : 
     739              :     Processes an event from a client socket.
     740              : 
     741              :  */
     742           49 : void ServerImpl::processClientSocket(uint64_t clientId, unsigned events)
     743              : {
     744              :     int sockFd;
     745           49 :     std::shared_ptr<ClientImpl> clientObj;
     746              : 
     747              :     // take the lock while accessing the client list
     748              :     {
     749           49 :         std::unique_lock<std::mutex> locker(m_clientsLock);
     750              : 
     751           49 :         auto it = m_clients.find(clientId);
     752           49 :         if (it == m_clients.end())
     753              :         {
     754              :             // should never happen
     755            0 :             RIALTO_IPC_LOG_ERROR("received an event from a socket with no matching client");
     756            0 :             return;
     757              :         }
     758              : 
     759              :         // check if the client is marked for closure, if so then just ignore the data
     760           49 :         if (m_condemnedClients.count(clientId) != 0)
     761              :         {
     762            0 :             return;
     763              :         }
     764              : 
     765              :         // get the socket that corresponds to the client connection
     766           49 :         sockFd = it->second.sock;
     767              : 
     768              :         // get the client object
     769           49 :         clientObj = it->second.client;
     770              :     }
     771              :     // lock is released here automatically
     772              : 
     773              :     // if there was an error disconnect the socket
     774           49 :     if (events & EPOLLERR)
     775              :     {
     776            0 :         RIALTO_IPC_LOG_ERROR("error detected on client socket - disconnecting client");
     777            0 :         disconnectClient(clientId);
     778            0 :         return;
     779              :     }
     780              : 
     781           49 :     if (events & EPOLLIN)
     782              :     {
     783              :         // read all messages from the client socket, we break out if the socket is closed
     784              :         // or EWOULDBLOCK is returned on a read (ie. no more messages to read)
     785              :         while (true)
     786              :         {
     787           70 :             struct msghdr msg = {nullptr};
     788           70 :             struct iovec io = {.iov_base = m_recvDataBuf, .iov_len = sizeof(m_recvDataBuf)};
     789              : 
     790           70 :             bzero(&msg, sizeof(msg));
     791           70 :             msg.msg_iov = &io;
     792           70 :             msg.msg_iovlen = 1;
     793           70 :             msg.msg_control = m_recvCtrlBuf;
     794           70 :             msg.msg_controllen = sizeof(m_recvCtrlBuf);
     795              : 
     796              :             // read one message
     797           70 :             ssize_t rd = TEMP_FAILURE_RETRY(recvmsg(sockFd, &msg, MSG_CMSG_CLOEXEC));
     798           70 :             if (rd < 0)
     799              :             {
     800           21 :                 if (errno != EWOULDBLOCK)
     801              :                 {
     802            0 :                     RIALTO_IPC_LOG_SYS_ERROR(errno, "error reading client socket");
     803            0 :                     disconnectClient(clientId);
     804              :                 }
     805              : 
     806           49 :                 break;
     807              :             }
     808           49 :             else if (rd == 0)
     809              :             {
     810              :                 // client closed connection, and we've read all data, add to the condemned set
     811              :                 // so is cleaned up once all the events are processed
     812           28 :                 disconnectClient(clientId);
     813              : 
     814           28 :                 break;
     815              :             }
     816           21 :             else if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))
     817              :             {
     818            0 :                 RIALTO_IPC_LOG_WARN("received message from client %" PRIu64 " truncated, discarding", clientId);
     819              : 
     820              :                 // make sure to close all the fds, otherwise we'll leak them
     821            0 :                 readMessageFds(&msg, 16);
     822              :             }
     823              :             else
     824              :             {
     825              :                 // if there is control data then assume fd(s) have been passed
     826           21 :                 if (msg.msg_controllen > 0)
     827              :                 {
     828            0 :                     processClientMessage(clientObj, m_recvDataBuf, rd, readMessageFds(&msg, 16));
     829              :                 }
     830              :                 else
     831              :                 {
     832           21 :                     processClientMessage(clientObj, m_recvDataBuf, rd);
     833              :                 }
     834              :             }
     835              :         }
     836              :     }
     837           49 : }
     838              : 
     839              : // -----------------------------------------------------------------------------
     840              : /*!
     841              :     \internal
     842              : 
     843              :     Places the received file descriptors into the request message.
     844              : 
     845              :     It works by iterating over the fields in the message, finding ones that are
     846              :     marked as 'field_is_fd' and then replacing the received integer value with
     847              :     an actual file descriptor.
     848              : 
     849              :  */
     850           21 : static bool addRequestFileDescriptors(google::protobuf::Message *request, const std::vector<FileDescriptor> &requestFds)
     851              : {
     852           21 :     auto fdIterator = requestFds.begin();
     853              : 
     854           21 :     const google::protobuf::Descriptor *kDescriptor = request->GetDescriptor();
     855           21 :     const google::protobuf::Reflection *kReflection = nullptr;
     856              : 
     857           21 :     const int n = kDescriptor->field_count();
     858           79 :     for (int i = 0; i < n; i++)
     859              :     {
     860           58 :         auto fieldDescriptor = kDescriptor->field(i);
     861           58 :         if (fieldDescriptor->options().HasExtension(field_is_fd) && fieldDescriptor->options().GetExtension(field_is_fd))
     862              :         {
     863            0 :             if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
     864              :             {
     865            0 :                 RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
     866            0 :                 return false;
     867              :             }
     868              : 
     869            0 :             if (!kReflection)
     870              :             {
     871            0 :                 kReflection = request->GetReflection();
     872              :             }
     873              : 
     874            0 :             if (kReflection->HasField(*request, fieldDescriptor))
     875              :             {
     876            0 :                 if (fdIterator == requestFds.end())
     877              :                 {
     878            0 :                     RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but one was supplied");
     879            0 :                     return false;
     880              :                 }
     881              : 
     882            0 :                 kReflection->SetInt32(request, fieldDescriptor, fdIterator->fd());
     883            0 :                 ++fdIterator;
     884              :             }
     885              :         }
     886              :     }
     887              : 
     888           21 :     if (fdIterator != requestFds.end())
     889              :     {
     890            0 :         RIALTO_IPC_LOG_ERROR("received too many file descriptors in the message");
     891            0 :         return false;
     892              :     }
     893              : 
     894           21 :     return true;
     895              : }
     896              : 
     897              : // -----------------------------------------------------------------------------
     898              : /*!
     899              :     \internal
     900              : 
     901              :     Processes a message received on a client socket.
     902              : 
     903              :  */
     904           21 : void ServerImpl::processClientMessage(const std::shared_ptr<ClientImpl> &client, const uint8_t *data, size_t dataLen,
     905              :                                       const std::vector<FileDescriptor> &fds)
     906              : {
     907           21 :     RIALTO_IPC_LOG_DEBUG("processing client message of size %zu bytes (%zu fds) from client %" PRId64, dataLen,
     908              :                          fds.size(), client->id());
     909              : 
     910              :     // parse the message
     911           21 :     transport::MessageToServer message;
     912           21 :     if (!message.ParseFromArray(data, static_cast<int>(dataLen)))
     913              :     {
     914            0 :         RIALTO_IPC_LOG_ERROR("invalid request");
     915            0 :         return;
     916              :     }
     917              : 
     918           21 :     if (message.has_call())
     919              :     {
     920           21 :         processMethodCall(client, message.call(), fds);
     921              :     }
     922              :     else
     923              :     {
     924            0 :         RIALTO_IPC_LOG_WARN("received unknown message type from client");
     925              :     }
     926           21 : }
     927              : 
     928              : // -----------------------------------------------------------------------------
     929              : /*!
     930              :     \internal
     931              : 
     932              :     Processes a method call requst from a client.
     933              : 
     934              :  */
     935           21 : void ServerImpl::processMethodCall(const std::shared_ptr<ClientImpl> &client, const transport::MethodCall &call,
     936              :                                    const std::vector<FileDescriptor> &fds)
     937              : {
     938              :     // try and find the service with the given name
     939           21 :     const std::string &kServiceName = call.service_name();
     940           21 :     auto it = client->m_services.find(kServiceName);
     941           21 :     if (it == client->m_services.end())
     942              :     {
     943            0 :         RIALTO_IPC_LOG_ERROR("unknown service request '%s'", kServiceName.c_str());
     944              : 
     945            0 :         sendErrorReply(client, call.serial_id(), "Unknown service '%s'", kServiceName.c_str());
     946            0 :         return;
     947              :     }
     948              : 
     949           21 :     std::shared_ptr<google::protobuf::Service> service = it->second;
     950              : 
     951              :     // try and find the method
     952           21 :     const std::string &kMethodName = call.method_name();
     953           21 :     const google::protobuf::MethodDescriptor *kMethod = service->GetDescriptor()->FindMethodByName(kMethodName);
     954           21 :     if (!kMethod)
     955              :     {
     956            0 :         RIALTO_IPC_LOG_ERROR("no method with name '%s'", kMethodName.c_str());
     957              : 
     958            0 :         sendErrorReply(client, call.serial_id(), "Unknown method '%s'", kMethodName.c_str());
     959            0 :         return;
     960              :     }
     961              : 
     962              :     // check if the method is expecting a reply
     963           21 :     const bool kNoReply = kMethod->options().HasExtension(no_reply) && kMethod->options().GetExtension(no_reply);
     964              : 
     965              :     // parse the request data
     966           21 :     google::protobuf::Message *requestMessage = service->GetRequestPrototype(kMethod).New();
     967           21 :     if (!requestMessage->ParseFromString(call.request_message()))
     968              :     {
     969            0 :         RIALTO_IPC_LOG_ERROR("failed to parse method from array");
     970              :     }
     971           21 :     else if (!addRequestFileDescriptors(requestMessage, fds))
     972              :     {
     973            0 :         RIALTO_IPC_LOG_ERROR("mismatch of file descriptors to the request");
     974              :     }
     975              :     else
     976              :     {
     977           21 :         RIALTO_IPC_LOG_DEBUG("call{ serial %" PRIu64 " } - %s.%s { %s }", call.serial_id(), kServiceName.c_str(),
     978              :                              kMethodName.c_str(), requestMessage->ShortDebugString().c_str());
     979              : 
     980           21 :         auto *controller = new ServerControllerImpl(client, call.serial_id());
     981              : 
     982           21 :         if (kNoReply)
     983              :         {
     984              :             // we should not send a reply for this call, so call the code to handle the
     985              :             // request, but no need to pass a controller, response or closure object
     986            1 :             static google::protobuf::internal::FunctionClosure0 nullClosure(&google::protobuf::DoNothing, false);
     987            1 :             service->CallMethod(kMethod, controller, requestMessage, nullptr, &nullClosure);
     988              : 
     989            1 :             delete controller;
     990              :         }
     991              :         else
     992              :         {
     993              :             // create a response
     994           20 :             google::protobuf::Message *responseMessage = service->GetResponsePrototype(kMethod).New();
     995              : 
     996              :             // this is finally where we call the service implementation to process the request
     997           20 :             service->CallMethod(kMethod, controller, requestMessage, responseMessage,
     998              :                                 google::protobuf::NewCallback(this, &ServerImpl::handleResponse, controller,
     999              :                                                               responseMessage));
    1000              :         }
    1001              :     }
    1002              : 
    1003           21 :     delete requestMessage;
    1004              : }
    1005              : 
    1006              : // -----------------------------------------------------------------------------
    1007              : /*!
    1008              :     \internal
    1009              :     \threadsafe
    1010              : 
    1011              :     Sends a message to the given client if still connected.
    1012              : 
    1013              :  */
    1014           20 : void ServerImpl::sendReply(uint64_t clientId, const std::shared_ptr<msghdr> &msg)
    1015              : {
    1016              :     // now take the lock (so the socket is not closed beneath us) and send the reply
    1017           20 :     std::lock_guard<std::mutex> locker(m_clientsLock);
    1018              : 
    1019           20 :     auto it = m_clients.find(clientId);
    1020           20 :     if (it == m_clients.end())
    1021              :     {
    1022            1 :         RIALTO_IPC_LOG_WARN("socket removed before error reply could be sent");
    1023              :     }
    1024           19 :     else if (it->second.sock < 0)
    1025              :     {
    1026            0 :         RIALTO_IPC_LOG_WARN("socket closed before error reply could be sent");
    1027              :     }
    1028           19 :     else if (!msg)
    1029              :     {
    1030            0 :         RIALTO_IPC_LOG_WARN("invalid msg to send on socket, ignoring");
    1031              :     }
    1032           19 :     else if (TEMP_FAILURE_RETRY(sendmsg(it->second.sock, msg.get(), MSG_NOSIGNAL)) !=
    1033           19 :              static_cast<ssize_t>(msg->msg_iov->iov_len))
    1034              :     {
    1035            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to send the complete error reply message");
    1036              :     }
    1037           20 : }
    1038              : 
    1039              : // -----------------------------------------------------------------------------
    1040              : /*!
    1041              :     \internal
    1042              : 
    1043              :     Sends an error back to the client with a reason string in \a format.
    1044              : 
    1045              :  */
    1046            0 : void ServerImpl::sendErrorReply(const std::shared_ptr<ClientImpl> &client, uint64_t serialId, const char *format, ...)
    1047              : {
    1048              :     // construct the error string
    1049              :     va_list ap;
    1050            0 :     va_start(ap, format);
    1051              :     char reason[512];
    1052            0 :     vsnprintf(reason, sizeof(reason), format, ap);
    1053            0 :     va_end(ap);
    1054              : 
    1055              :     // construct the reply message
    1056            0 :     auto msg = populateErrorReply(client, serialId, reason);
    1057              : 
    1058              :     // and send it
    1059            0 :     sendReply(client->id(), msg);
    1060              : }
    1061              : 
    1062              : // -----------------------------------------------------------------------------
    1063              : /*!
    1064              :     \internal
    1065              :     \threadsafe
    1066              : 
    1067              :     Called once the service handler code has processed the request and
    1068              :     completed it.  This may be called from a different thread if the service
    1069              :     implementation decided to off-load the request to a processing thread.
    1070              : 
    1071              :  */
    1072           20 : void ServerImpl::handleResponse(ServerControllerImpl *controller, google::protobuf::Message *response)
    1073              : {
    1074           20 :     if (!controller || !controller->m_kClient)
    1075              :     {
    1076            0 :         RIALTO_IPC_LOG_ERROR("missing controller or attached client");
    1077            0 :         return;
    1078              :     }
    1079              : 
    1080           20 :     const std::shared_ptr<const ClientImpl> kClient = controller->m_kClient;
    1081           20 :     const uint64_t kClientId = kClient->id();
    1082              : 
    1083           20 :     std::shared_ptr<msghdr> message;
    1084           20 :     if (!controller->m_failed)
    1085              :     {
    1086           13 :         message = populateReply(kClient, controller->m_kSerialId, response);
    1087              :     }
    1088              :     else
    1089              :     {
    1090            7 :         message = populateErrorReply(kClient, controller->m_kSerialId, controller->m_failureReason);
    1091              :     }
    1092              : 
    1093              :     // no longer need the controller or the response objects
    1094           20 :     delete response;
    1095           20 :     delete controller;
    1096              : 
    1097              :     // send the reply message to the given client
    1098           20 :     sendReply(kClientId, message);
    1099              : }
    1100              : 
    1101              : // -----------------------------------------------------------------------------
    1102              : /*!
    1103              :     \internal
    1104              : 
    1105              :     Gets all the file descriptors that are stored in the \a response message
    1106              :     and returns them as a vector of ints.
    1107              : 
    1108              :     It works by iterating over the fields in the message, finding ones that are
    1109              :     marked as 'field_is_fd' and inserting the fd value into control part of the
    1110              :     message.  The actual integer sent in the data part of the message is set to
    1111              :     -1.
    1112              : 
    1113              :  */
    1114           17 : static std::vector<int> getResponseFileDescriptors(google::protobuf::Message *response)
    1115              : {
    1116           17 :     std::vector<int> fds;
    1117              : 
    1118              :     // process any file descriptors from the response message
    1119           17 :     const google::protobuf::Descriptor *kDescriptor = response->GetDescriptor();
    1120           17 :     const google::protobuf::Reflection *kReflection = nullptr;
    1121              : 
    1122           17 :     const int n = kDescriptor->field_count();
    1123           34 :     for (int i = 0; i < n; i++)
    1124              :     {
    1125           17 :         auto fieldDescriptor = kDescriptor->field(i);
    1126           17 :         if (fieldDescriptor->options().HasExtension(field_is_fd) && fieldDescriptor->options().GetExtension(field_is_fd))
    1127              :         {
    1128            0 :             if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
    1129              :             {
    1130            0 :                 RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
    1131            0 :                 return {};
    1132              :             }
    1133              : 
    1134            0 :             if (!kReflection)
    1135              :             {
    1136            0 :                 kReflection = response->GetReflection();
    1137              :             }
    1138              : 
    1139            0 :             if (kReflection->HasField(*response, fieldDescriptor))
    1140              :             {
    1141            0 :                 fds.push_back(kReflection->GetInt32(*response, fieldDescriptor));
    1142            0 :                 kReflection->SetInt32(response, fieldDescriptor, -1);
    1143              :             }
    1144              :         }
    1145              :     }
    1146              : 
    1147           17 :     return fds;
    1148              : }
    1149              : 
    1150              : // -----------------------------------------------------------------------------
    1151              : /*!
    1152              :     \internal
    1153              :     \static
    1154              : 
    1155              :     Populates tha socket message buffer with the reply data for the RPC request.
    1156              : 
    1157              :  */
    1158           13 : std::shared_ptr<msghdr> ServerImpl::populateReply(const std::shared_ptr<const ClientImpl> &client, uint64_t serialId,
    1159              :                                                   google::protobuf::Message *response)
    1160              : {
    1161              :     // create the base reply
    1162           13 :     transport::MessageFromServer message;
    1163           13 :     transport::MethodCallReply *reply = message.mutable_reply();
    1164           13 :     if (!reply)
    1165              :     {
    1166            0 :         RIALTO_IPC_LOG_ERROR("failed to create mutable reply object");
    1167            0 :         return nullptr;
    1168              :     }
    1169              : 
    1170              :     // convert the message response to a data string
    1171           13 :     std::string respString = response->SerializeAsString();
    1172              : 
    1173              :     // wrap in a transport response and send that
    1174           13 :     reply->set_reply_id(serialId);
    1175           13 :     reply->set_reply_message(std::move(respString));
    1176              : 
    1177              :     // next need to check if the response message has any file descriptors in
    1178              :     // it that need to be attached
    1179           13 :     const std::vector<int> kFds = getResponseFileDescriptors(response);
    1180           13 :     const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
    1181              : 
    1182              :     // calculate the size of the reply
    1183           13 :     const size_t kRequiredDataLen = message.ByteSizeLong();
    1184           13 :     if (kRequiredDataLen > m_kMaxMessageLen)
    1185              :     {
    1186            0 :         RIALTO_IPC_LOG_ERROR("reply exceeds maximum message limit (%zu, max %zu)", kRequiredDataLen, m_kMaxMessageLen);
    1187              : 
    1188              :         // error message is too big, replace with a generic error
    1189            0 :         return populateErrorReply(client, serialId, "Internal error - reply message to large");
    1190              :     }
    1191              : 
    1192              :     // build the socket message to send
    1193              :     auto msgBuf =
    1194           13 :         m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + kRequiredDataLen);
    1195              : 
    1196           13 :     auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
    1197           13 :     bzero(header, sizeof(msghdr));
    1198              : 
    1199           13 :     auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
    1200           13 :     header->msg_control = ctrl;
    1201           13 :     header->msg_controllen = kRequiredCtrlLen;
    1202              : 
    1203           13 :     auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
    1204           13 :     header->msg_iov = iov;
    1205           13 :     header->msg_iovlen = 1;
    1206              : 
    1207           13 :     auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
    1208           13 :     iov->iov_base = data;
    1209           13 :     iov->iov_len = kRequiredDataLen;
    1210              : 
    1211              :     // copy in the data
    1212           13 :     message.SerializeWithCachedSizesToArray(data);
    1213              : 
    1214              :     // add the fds
    1215           13 :     if (!kFds.empty())
    1216              :     {
    1217            0 :         struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
    1218            0 :         if (!cmsg)
    1219              :         {
    1220            0 :             RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
    1221            0 :             return nullptr;
    1222              :         }
    1223              : 
    1224            0 :         cmsg->cmsg_level = SOL_SOCKET;
    1225            0 :         cmsg->cmsg_type = SCM_RIGHTS;
    1226            0 :         cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
    1227            0 :         memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
    1228            0 :         header->msg_controllen = cmsg->cmsg_len;
    1229              :     }
    1230              : 
    1231           13 :     RIALTO_IPC_LOG_DEBUG("reply{ serial %" PRIu64 " } - { %s }", serialId, response->ShortDebugString().c_str());
    1232              : 
    1233              :     // std::reinterpret_pointer_cast is only implemented in C++17 and newer, so for
    1234              :     // now do it manually
    1235           13 :     return std::shared_ptr<msghdr>(msgBuf, reinterpret_cast<msghdr *>(msgBuf.get()));
    1236              : }
    1237              : 
    1238              : // -----------------------------------------------------------------------------
    1239              : /*!
    1240              :     \internal
    1241              : 
    1242              :     Populates tha socket message buffer with a reply error message.
    1243              : 
    1244              :  */
    1245            7 : std::shared_ptr<msghdr> ServerImpl::populateErrorReply(const std::shared_ptr<const ClientImpl> &client,
    1246              :                                                        uint64_t serialId, const std::string &reason)
    1247              : {
    1248              :     // create the base reply
    1249            7 :     transport::MessageFromServer message;
    1250            7 :     transport::MethodCallError *error = message.mutable_error();
    1251            7 :     error->set_reply_id(serialId);
    1252              :     error->set_error_reason(reason);
    1253              : 
    1254              :     // check the message will fit
    1255            7 :     size_t replySize = message.ByteSizeLong();
    1256            7 :     if (replySize > m_kMaxMessageLen)
    1257              :     {
    1258            0 :         RIALTO_IPC_LOG_ERROR("error reply exceeds max message size");
    1259              : 
    1260              :         // error message is to big, replace with a generic error
    1261              :         error->set_error_reason("Error message truncated");
    1262            0 :         replySize = message.ByteSizeLong();
    1263              :     }
    1264              : 
    1265            7 :     RIALTO_IPC_LOG_DEBUG("error{ serial %" PRIu64 " } - \"%s\"", serialId, reason.c_str());
    1266              : 
    1267              :     // construct the message to send on the socket
    1268            7 :     auto msgBuf = m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + replySize);
    1269              : 
    1270            7 :     auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
    1271            7 :     bzero(header, sizeof(msghdr));
    1272              : 
    1273            7 :     auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr));
    1274            7 :     header->msg_iov = iov;
    1275            7 :     header->msg_iovlen = 1;
    1276              : 
    1277            7 :     auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + sizeof(iovec));
    1278            7 :     iov->iov_base = data;
    1279            7 :     iov->iov_len = replySize;
    1280              : 
    1281              :     // serialise the reply to the message buffer
    1282            7 :     message.SerializeWithCachedSizesToArray(data);
    1283              : 
    1284              :     // std::reinterpret_pointer_cast is only implemented in C++17 and newer, so for
    1285              :     // now do it manually
    1286           14 :     return std::shared_ptr<msghdr>(msgBuf, reinterpret_cast<msghdr *>(msgBuf.get()));
    1287            7 : }
    1288              : 
    1289              : // -----------------------------------------------------------------------------
    1290              : /*!
    1291              :     \threadsafe
    1292              : 
    1293              :     Returns \c true if the client with \a clientId is currently connected to
    1294              :     the server.
    1295              : 
    1296              :  */
    1297           61 : bool ServerImpl::isClientConnected(uint64_t clientId) const
    1298              : {
    1299           61 :     std::unique_lock<std::mutex> locker(m_clientsLock);
    1300          122 :     return (m_clients.count(clientId) > 0);
    1301           61 : }
    1302              : 
    1303              : // -----------------------------------------------------------------------------
    1304              : /*!
    1305              :     \threadsafe
    1306              : 
    1307              :     May be called internally when there is an error on the socket, or externally
    1308              :     (possibly from a different thread) if the handler or service code decides to
    1309              :     close the client connection.
    1310              : 
    1311              :  */
    1312           29 : void ServerImpl::disconnectClient(uint64_t clientId)
    1313              : {
    1314              :     {
    1315           29 :         std::lock_guard<std::mutex> locker(m_clientsLock);
    1316           29 :         m_condemnedClients.insert(clientId);
    1317              :     }
    1318              : 
    1319           29 :     wakeEventLoop();
    1320              : }
    1321              : 
    1322              : // -----------------------------------------------------------------------------
    1323              : /*!
    1324              :     \threadsafe
    1325              : 
    1326              :     Called via the IAVBusClient interface when an async event should be sent.
    1327              : 
    1328              :     This may be called from any thread, or from within the rpc message handler.
    1329              : 
    1330              :     The \a clientId is the client to send the event to.
    1331              : 
    1332              :  */
    1333            4 : bool ServerImpl::sendEvent(uint64_t clientId, const std::shared_ptr<google::protobuf::Message> &eventMessage)
    1334              : {
    1335              :     // gets the file descriptors from the event message
    1336            4 :     const std::vector<int> kFds = getResponseFileDescriptors(eventMessage.get());
    1337            4 :     const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
    1338              : 
    1339              :     // create the base reply
    1340            4 :     transport::MessageFromServer message;
    1341            4 :     transport::EventFromServer *event = message.mutable_event();
    1342            4 :     if (!event)
    1343              :     {
    1344            0 :         RIALTO_IPC_LOG_ERROR("failed to create mutable event object");
    1345            0 :         return false;
    1346              :     }
    1347              : 
    1348            8 :     event->set_event_name(eventMessage->GetTypeName());
    1349              : 
    1350              :     // convert the event to a data string
    1351            4 :     std::string respString = eventMessage->SerializeAsString();
    1352              : 
    1353              :     // wrap in a transport response and send that
    1354            4 :     event->set_message(std::move(respString));
    1355              : 
    1356              :     // check the reply will fit
    1357            4 :     size_t requiredDataLen = message.ByteSizeLong();
    1358            4 :     if (requiredDataLen > m_kMaxMessageLen)
    1359              :     {
    1360            0 :         RIALTO_IPC_LOG_ERROR("event message to big to fit in buffer (size %zu, max size %zu)", requiredDataLen,
    1361              :                              m_kMaxMessageLen);
    1362            0 :         return false;
    1363              :     }
    1364              : 
    1365              :     // build the socket message to send
    1366              :     auto msgBuf =
    1367            4 :         m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + requiredDataLen);
    1368              : 
    1369            4 :     auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
    1370            4 :     bzero(header, sizeof(msghdr));
    1371              : 
    1372            4 :     auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
    1373            4 :     header->msg_control = ctrl;
    1374            4 :     header->msg_controllen = kRequiredCtrlLen;
    1375              : 
    1376            4 :     auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
    1377            4 :     header->msg_iov = iov;
    1378            4 :     header->msg_iovlen = 1;
    1379              : 
    1380            4 :     auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
    1381            4 :     iov->iov_base = data;
    1382            4 :     iov->iov_len = requiredDataLen;
    1383              : 
    1384              :     // copy in the data
    1385            4 :     message.SerializeWithCachedSizesToArray(data);
    1386              : 
    1387              :     // add the fds
    1388            4 :     if (!kFds.empty())
    1389              :     {
    1390            0 :         struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
    1391            0 :         if (!cmsg)
    1392              :         {
    1393            0 :             RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
    1394            0 :             return false;
    1395              :         }
    1396              : 
    1397            0 :         cmsg->cmsg_level = SOL_SOCKET;
    1398            0 :         cmsg->cmsg_type = SCM_RIGHTS;
    1399            0 :         cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
    1400            0 :         memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
    1401            0 :         header->msg_controllen = cmsg->cmsg_len;
    1402              :     }
    1403              : 
    1404              :     // finally, take the lock (so the socket is not closed beneath us) and send the reply
    1405              :     {
    1406            4 :         std::unique_lock<std::mutex> locker(m_clientsLock);
    1407              : 
    1408            4 :         auto it = m_clients.find(clientId);
    1409            4 :         if (it == m_clients.end() || it->second.sock < 0)
    1410              :         {
    1411            0 :             RIALTO_IPC_LOG_WARN("socket closed before event could be sent");
    1412            0 :             return false;
    1413              :         }
    1414            4 :         else if (TEMP_FAILURE_RETRY(sendmsg(it->second.sock, header, MSG_NOSIGNAL)) !=
    1415            4 :                  static_cast<ssize_t>(requiredDataLen))
    1416              :         {
    1417            0 :             RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to send the complete event message");
    1418            0 :             return false;
    1419              :         }
    1420            4 :     }
    1421              : 
    1422            4 :     RIALTO_IPC_LOG_DEBUG("event{ %s } - { %s }", eventMessage->GetTypeName().c_str(),
    1423              :                          eventMessage->ShortDebugString().c_str());
    1424              : 
    1425            4 :     return true;
    1426              : }
    1427              : } // namespace firebolt::rialto::ipc
        

Generated by: LCOV version 2.0-1