LCOV - code coverage report
Current view: top level - ipc/server/source - IpcServerImpl.cpp (source / functions) Coverage Total Hit
Test: coverage.info Lines: 62.5 % 563 352
Test Date: 2025-03-21 11:02:39 Functions: 86.2 % 29 25

            Line data    Source code
       1              : /*
       2              :  * If not stated otherwise in this file or this component's LICENSE file the
       3              :  * following copyright and licenses apply:
       4              :  *
       5              :  * Copyright 2022 Sky UK
       6              :  *
       7              :  * Licensed under the Apache License, Version 2.0 (the "License");
       8              :  * you may not use this file except in compliance with the License.
       9              :  * You may obtain a copy of the License at
      10              :  *
      11              :  * http://www.apache.org/licenses/LICENSE-2.0
      12              :  *
      13              :  * Unless required by applicable law or agreed to in writing, software
      14              :  * distributed under the License is distributed on an "AS IS" BASIS,
      15              :  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      16              :  * See the License for the specific language governing permissions and
      17              :  * limitations under the License.
      18              :  */
      19              : 
      20              : #include "IpcServerImpl.h"
      21              : #include "IIpcServerFactory.h"
      22              : #include "IpcClientImpl.h"
      23              : #include "IpcLogging.h"
      24              : #include "IpcServerControllerImpl.h"
      25              : 
      26              : #include "rialtoipc.pb.h"
      27              : 
      28              : #include <google/protobuf/service.h>
      29              : 
      30              : #include <algorithm>
      31              : #include <cinttypes>
      32              : #include <cstdarg>
      33              : #include <string>
      34              : #include <utility>
      35              : #include <vector>
      36              : 
      37              : #include <fcntl.h>
      38              : #include <poll.h>
      39              : #include <sys/epoll.h>
      40              : #include <sys/eventfd.h>
      41              : #include <sys/file.h>
      42              : #include <sys/socket.h>
      43              : #include <sys/stat.h>
      44              : #include <sys/un.h>
      45              : #include <unistd.h>
      46              : 
      47              : #define WAKE_EVENT_ID uint64_t(0)
      48              : #define FIRST_LISTENING_SOCKET_ID uint64_t(1)
      49              : #define FIRST_CLIENT_ID uint64_t(10000)
      50              : 
      51              : namespace firebolt::rialto::ipc
      52              : {
      53              : const size_t ServerImpl::m_kMaxMessageLen = (128 * 1024);
      54              : 
      55           32 : std::shared_ptr<IServerFactory> IServerFactory::createFactory()
      56              : {
      57           32 :     std::shared_ptr<IServerFactory> factory;
      58              :     try
      59              :     {
      60           32 :         factory = std::make_shared<ServerFactory>();
      61              :     }
      62            0 :     catch (const std::exception &e)
      63              :     {
      64            0 :         RIALTO_IPC_LOG_ERROR("Failed to create the server factory, reason: %s", e.what());
      65              :     }
      66              : 
      67           32 :     return factory;
      68              : }
      69              : 
      70           32 : std::shared_ptr<IServer> ServerFactory::create()
      71              : {
      72           32 :     return std::make_shared<ServerImpl>();
      73              : }
      74              : 
      75           32 : ServerImpl::ServerImpl()
      76           32 :     : m_pollFd(-1), m_wakeEventFd(-1), m_socketIdCounter(FIRST_LISTENING_SOCKET_ID), m_clientIdCounter(FIRST_CLIENT_ID),
      77           64 :       m_recvDataBuf{0}, m_recvCtrlBuf{0}
      78              : {
      79              :     // create the eventfd use to wake the poll loop
      80           32 :     m_wakeEventFd = eventfd(0, EFD_CLOEXEC);
      81           32 :     if (m_wakeEventFd < 0)
      82              :     {
      83            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "eventfd failed");
      84            0 :         return;
      85              :     }
      86              : 
      87              :     // create epoll loop
      88           32 :     m_pollFd = epoll_create1(EPOLL_CLOEXEC);
      89           32 :     if (m_pollFd < 0)
      90              :     {
      91            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_create1 failed");
      92            0 :         return;
      93              :     }
      94              : 
      95              :     // add the wake event to epoll
      96           32 :     epoll_event event = {.events = EPOLLIN, .data = {.u64 = WAKE_EVENT_ID}};
      97           32 :     if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, m_wakeEventFd, &event) != 0)
      98              :     {
      99            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
     100              :     }
     101              : }
     102              : 
     103           32 : ServerImpl::~ServerImpl()
     104              : {
     105           32 :     if ((m_pollFd >= 0) && (close(m_pollFd) != 0))
     106            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close epoll");
     107              : 
     108           32 :     if ((m_wakeEventFd >= 0) && (close(m_wakeEventFd) != 0))
     109            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close eventfd");
     110              : 
     111           45 :     for (const auto &entry : m_sockets)
     112              :     {
     113           13 :         const Socket &kSocket = entry.second;
     114              : 
     115           13 :         if (kSocket.isOwned)
     116              :         {
     117            9 :             if (unlink(kSocket.sockPath.c_str()) != 0)
     118            0 :                 RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket @ '%s'", kSocket.sockPath.c_str());
     119            9 :             if (close(kSocket.sockFd) != 0)
     120            0 :                 RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close listening socket");
     121              : 
     122            9 :             if (unlink(kSocket.lockPath.c_str()) != 0)
     123            0 :                 RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket lock file @ '%s'", kSocket.lockPath.c_str());
     124            9 :             if (close(kSocket.lockFd) != 0)
     125            0 :                 RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket lock file");
     126              :         }
     127              :     }
     128           32 : }
     129              : 
     130              : // -----------------------------------------------------------------------------
     131              : /*!
     132              :     \static
     133              :     \internal
     134              : 
     135              :     Creates (if required) and takes the file lock associated with the socket
     136              :     path in the \a socket object.
     137              : 
     138              :  */
     139            9 : bool ServerImpl::getSocketLock(Socket *socket)
     140              : {
     141            9 :     std::string lockPath = socket->sockPath + ".lock";
     142            9 :     int fileDescriptor = open(lockPath.c_str(), O_CREAT | O_CLOEXEC, (S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP));
     143            9 :     if (fileDescriptor < 0)
     144              :     {
     145            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to create / open lockfile @ '%s' (check permissions)", lockPath.c_str());
     146            0 :         return false;
     147              :     }
     148              : 
     149            9 :     if (flock(fileDescriptor, LOCK_EX | LOCK_NB) < 0)
     150              :     {
     151            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to lock lockfile @ '%s', maybe another server is running",
     152              :                                  lockPath.c_str());
     153            0 :         close(fileDescriptor);
     154            0 :         return false;
     155              :     }
     156              : 
     157            9 :     struct stat sbuf = {0};
     158            9 :     if (stat(socket->sockPath.c_str(), &sbuf) < 0)
     159              :     {
     160            9 :         if (errno != ENOENT)
     161              :         {
     162            0 :             RIALTO_IPC_LOG_SYS_ERROR(errno, "did not manage to stat existing socket @ '%s'", socket->sockPath.c_str());
     163            0 :             close(fileDescriptor);
     164            0 :             return false;
     165              :         }
     166              :     }
     167            0 :     else if ((sbuf.st_mode & S_IWUSR) || (sbuf.st_mode & S_IWGRP))
     168              :     {
     169            0 :         unlink(socket->sockPath.c_str());
     170              :     }
     171              : 
     172            9 :     socket->lockFd = fileDescriptor;
     173            9 :     socket->lockPath = std::move(lockPath);
     174              : 
     175            9 :     return true;
     176              : }
     177              : 
     178              : // -----------------------------------------------------------------------------
     179              : /*!
     180              :     \static
     181              :     \internal
     182              : 
     183              :     Closes the file descriptors and the deletes the files stored in the \a socket
     184              :     object.
     185              : 
     186              :  */
     187            0 : void ServerImpl::closeListeningSocket(Socket *socket)
     188              : {
     189            0 :     if (!socket->isOwned)
     190              :     {
     191            0 :         return;
     192              :     }
     193              : 
     194            0 :     if (!socket->sockPath.empty() && (unlink(socket->sockPath.c_str()) != 0) && (errno != ENOENT))
     195            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket @ '%s'", socket->sockPath.c_str());
     196            0 :     if ((socket->sockFd >= 0) && (close(socket->sockFd) != 0))
     197            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close listening socket");
     198              : 
     199            0 :     if (!socket->lockPath.empty() && (unlink(socket->lockPath.c_str()) != 0) && (errno != ENOENT))
     200            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket lock file @ '%s'", socket->lockPath.c_str());
     201            0 :     if ((socket->lockFd >= 0) && (close(socket->lockFd) != 0))
     202            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket lock file");
     203              : 
     204            0 :     socket->sockFd = -1;
     205            0 :     socket->sockPath.clear();
     206              : 
     207            0 :     socket->lockFd = -1;
     208            0 :     socket->lockPath.clear();
     209              : }
     210              : 
     211            9 : bool ServerImpl::addSocket(const std::string &socketPath,
     212              :                            std::function<void(const std::shared_ptr<IClient> &)> clientConnectedCb,
     213              :                            std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb)
     214              : {
     215              :     // store the path
     216            9 :     Socket socket;
     217            9 :     socket.sockPath = socketPath;
     218              : 
     219              :     // create the socket
     220            9 :     socket.sockFd = ::socket(AF_UNIX, SOCK_SEQPACKET | SOCK_CLOEXEC | SOCK_NONBLOCK, 0);
     221            9 :     if (socket.sockFd == -1)
     222              :     {
     223            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "socket error");
     224            0 :         return false;
     225              :     }
     226              : 
     227              :     // get the socket lock
     228            9 :     if (!getSocketLock(&socket))
     229              :     {
     230            0 :         closeListeningSocket(&socket);
     231            0 :         return false;
     232              :     }
     233              : 
     234              :     // bind to the given path
     235            9 :     struct sockaddr_un addr = {0};
     236            9 :     memset(&addr, 0x00, sizeof(addr));
     237            9 :     addr.sun_family = AF_UNIX;
     238            9 :     strncpy(addr.sun_path, socketPath.c_str(), sizeof(addr.sun_path) - 1);
     239              : 
     240            9 :     if (bind(socket.sockFd, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr)) == -1)
     241              :     {
     242            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "bind error");
     243              : 
     244            0 :         closeListeningSocket(&socket);
     245            0 :         return false;
     246              :     }
     247              : 
     248              :     // put in listening mode
     249            9 :     if (listen(socket.sockFd, 1) == -1)
     250              :     {
     251            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "listen error");
     252              : 
     253            0 :         closeListeningSocket(&socket);
     254            0 :         return false;
     255              :     }
     256              : 
     257              :     // create an id for the listening socket
     258            9 :     const uint64_t kSocketId = m_socketIdCounter++;
     259            9 :     if (kSocketId >= FIRST_CLIENT_ID)
     260              :     {
     261              :         // should never happen, we'd run out of file descriptors before
     262              :         // we hit the 10k limit on listening sockets
     263            0 :         RIALTO_IPC_LOG_ERROR("too many listening sockets");
     264              : 
     265            0 :         closeListeningSocket(&socket);
     266            0 :         return false;
     267              :     }
     268              : 
     269              :     // add the socket to epoll
     270            9 :     epoll_event event = {.events = EPOLLIN, .data = {.u64 = kSocketId}};
     271            9 :     if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socket.sockFd, &event) != 0)
     272              :     {
     273            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add listening socket");
     274              : 
     275            0 :         closeListeningSocket(&socket);
     276            0 :         return false;
     277              :     }
     278              : 
     279              :     // store the client connected / disconnected callbacks
     280            9 :     socket.connectedCb = clientConnectedCb;
     281            9 :     socket.disconnectedCb = clientDisconnectedCb;
     282              : 
     283              :     // add to the internal map
     284            9 :     std::lock_guard<std::mutex> locker(m_socketsLock);
     285            9 :     m_sockets.emplace(kSocketId, std::move(socket));
     286              : 
     287            9 :     RIALTO_IPC_LOG_INFO("added listening socket '%s' to server", socketPath.c_str());
     288              : 
     289            9 :     return true;
     290              : }
     291              : 
     292            4 : bool ServerImpl::addSocket(int fd, std::function<void(const std::shared_ptr<IClient> &)> clientConnectedCb,
     293              :                            std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb)
     294              : {
     295              :     // store the path
     296            4 :     Socket socket;
     297            4 :     socket.isOwned = false;
     298            4 :     socket.sockFd = fd;
     299              : 
     300              :     // put in listening mode
     301            4 :     if (listen(socket.sockFd, 1) == -1)
     302              :     {
     303            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "listen error");
     304            0 :         return false;
     305              :     }
     306              : 
     307              :     // create an id for the listening socket
     308            4 :     const uint64_t kSocketId = m_socketIdCounter++;
     309            4 :     if (kSocketId >= FIRST_CLIENT_ID)
     310              :     {
     311              :         // should never happen, we'd run out of file descriptors before
     312              :         // we hit the 10k limit on listening sockets
     313            0 :         RIALTO_IPC_LOG_ERROR("too many listening sockets");
     314            0 :         return false;
     315              :     }
     316              : 
     317              :     // add the socket to epoll
     318            4 :     epoll_event event = {.events = EPOLLIN, .data = {.u64 = kSocketId}};
     319            4 :     if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socket.sockFd, &event) != 0)
     320              :     {
     321            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add listening socket");
     322            0 :         return false;
     323              :     }
     324              : 
     325              :     // store the client connected / disconnected callbacks
     326            4 :     socket.connectedCb = clientConnectedCb;
     327            4 :     socket.disconnectedCb = clientDisconnectedCb;
     328              : 
     329              :     // add to the internal map
     330            4 :     std::lock_guard<std::mutex> locker(m_socketsLock);
     331            4 :     m_sockets.emplace(kSocketId, std::move(socket));
     332              : 
     333            4 :     RIALTO_IPC_LOG_INFO("added listening socket with fd: %d to server", fd);
     334              : 
     335            4 :     return true;
     336              : }
     337              : 
     338           16 : std::shared_ptr<IClient> ServerImpl::addClient(int socketFd,
     339              :                                                std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb)
     340              : {
     341              :     // sanity check the supplied socket is of the right type
     342              :     struct sockaddr addr;
     343           16 :     socklen_t len = sizeof(sockaddr);
     344           16 :     if ((getsockname(socketFd, &addr, &len) < 0) || (len < sizeof(sa_family_t)))
     345              :     {
     346            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get name of supplied socket");
     347            0 :         return nullptr;
     348              :     }
     349           16 :     if (addr.sa_family != AF_UNIX)
     350              :     {
     351            0 :         RIALTO_IPC_LOG_ERROR("supplied client socket is not a unix domain socket");
     352            0 :         return nullptr;
     353              :     }
     354              : 
     355           16 :     int type = 0;
     356           16 :     len = sizeof(type);
     357           16 :     if ((getsockopt(socketFd, SOL_SOCKET, SO_TYPE, &type, &len) < 0) || (len != sizeof(type)))
     358              :     {
     359            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
     360            0 :         return nullptr;
     361              :     }
     362           16 :     if (type != SOCK_SEQPACKET)
     363              :     {
     364            0 :         RIALTO_IPC_LOG_ERROR("supplied client socket is not of type SOCK_SEQPACKET");
     365            0 :         return nullptr;
     366              :     }
     367              : 
     368              :     // dup the socket and set the O_CLOEXEC bit
     369           16 :     int duppedFd = fcntl(socketFd, F_DUPFD_CLOEXEC, 3);
     370           16 :     if (duppedFd < 0)
     371              :     {
     372            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
     373            0 :         return nullptr;
     374              :     }
     375              : 
     376              :     // ensure the SOCK_NONBLOCK flag is set
     377           16 :     int flags = fcntl(duppedFd, F_GETFL);
     378           16 :     if ((flags < 0) || (fcntl(duppedFd, F_SETFL, flags | O_NONBLOCK) < 0))
     379              :     {
     380            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to set socket to non-blocking mode");
     381            0 :         close(duppedFd);
     382            0 :         return nullptr;
     383              :     }
     384              : 
     385              :     // finally, add the socket to the list of clients
     386           48 :     auto client = addClientSocket(duppedFd, "", std::move(clientDisconnectedCb));
     387           16 :     if (!client)
     388              :     {
     389            0 :         close(duppedFd);
     390              :     }
     391              : 
     392           16 :     return client;
     393              : }
     394              : 
     395            0 : int ServerImpl::fd() const
     396              : {
     397            0 :     return m_pollFd;
     398              : }
     399              : 
     400              : // -----------------------------------------------------------------------------
     401              : /*!
     402              :     \internal
     403              : 
     404              :     Writes to the eventfd to wake the event loop.  Typically called when requesting
     405              :     it to shutdown or external code has requested that a client be disconnected.
     406              : 
     407              :  */
     408           29 : void ServerImpl::wakeEventLoop() const
     409              : {
     410           29 :     if (m_wakeEventFd < 0)
     411              :     {
     412            0 :         RIALTO_IPC_LOG_ERROR("invalid wake event fd");
     413              :     }
     414              :     else
     415              :     {
     416           29 :         uint64_t value = 1;
     417           29 :         if (TEMP_FAILURE_RETRY(::write(m_wakeEventFd, &value, sizeof(value))) != sizeof(value))
     418              :         {
     419            0 :             RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to write to the event fd");
     420              :         }
     421              :     }
     422           29 : }
     423              : 
     424         2893 : bool ServerImpl::wait(int timeoutMSecs)
     425              : {
     426         2893 :     if (m_pollFd < 0)
     427              :     {
     428            0 :         return false;
     429              :     }
     430              : 
     431              :     // wait for any event (with timeout)
     432              :     struct pollfd fds[2];
     433         2893 :     fds[0].fd = m_pollFd;
     434         2893 :     fds[0].events = POLLIN;
     435              : 
     436         2893 :     int rc = TEMP_FAILURE_RETRY(poll(fds, 1, timeoutMSecs));
     437         2893 :     if (rc < 0)
     438              :     {
     439            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "poll failed?");
     440            0 :         return false;
     441              :     }
     442              : 
     443         2893 :     return true;
     444              : }
     445              : 
     446         2922 : bool ServerImpl::process()
     447              : {
     448         2922 :     if (m_pollFd < 0)
     449              :     {
     450            0 :         RIALTO_IPC_LOG_ERROR("missing epoll");
     451            0 :         return false;
     452              :     }
     453              : 
     454              :     // read up to 32 events
     455         2922 :     const int kMaxEvents = 32;
     456              :     struct epoll_event events[kMaxEvents];
     457              : 
     458         2922 :     int rc = TEMP_FAILURE_RETRY(epoll_wait(m_pollFd, events, kMaxEvents, 0));
     459         2922 :     if (rc < 0)
     460              :     {
     461            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_wait failed");
     462            0 :         return false;
     463              :     }
     464              : 
     465              :     // process the events (maybe 0 if timed out)
     466         2987 :     for (int i = 0; i < rc; i++)
     467              :     {
     468           65 :         const struct epoll_event &kEvent = events[i];
     469              : 
     470              :         // check if a wake event, in which case just clear the eventfd
     471           65 :         if (kEvent.data.u64 == WAKE_EVENT_ID)
     472              :         {
     473              :             uint64_t ignore;
     474            3 :             if (TEMP_FAILURE_RETRY(::read(m_wakeEventFd, &ignore, sizeof(ignore))) != sizeof(ignore))
     475              :             {
     476            0 :                 RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to read wake eventfd");
     477              :             }
     478              :         }
     479              : 
     480              :         // check for events on the listening socket
     481           62 :         else if (kEvent.data.u64 < FIRST_CLIENT_ID)
     482              :         {
     483           13 :             if (kEvent.events & EPOLLIN)
     484           13 :                 processNewConnection(kEvent.data.u64);
     485           13 :             if (kEvent.events & EPOLLERR)
     486            0 :                 RIALTO_IPC_LOG_ERROR("error occurred on listening socket");
     487              :         }
     488              : 
     489              :         // otherwise, the event must have come from a socket
     490              :         else
     491              :         {
     492           49 :             processClientSocket(kEvent.data.u64, kEvent.events);
     493              :         }
     494              :     }
     495              : 
     496              :     // if we have client sockets that are condemned then we need to shut down
     497              :     // and close them as well as remove from epoll
     498         2922 :     std::unique_lock<std::mutex> locker(m_clientsLock);
     499         2922 :     if (!m_condemnedClients.empty())
     500              :     {
     501              :         // take a copy of the set so we can process without the lock held
     502           29 :         std::set<uint64_t> theCondemned;
     503           29 :         m_condemnedClients.swap(theCondemned);
     504              : 
     505           58 :         for (uint64_t clientId : theCondemned)
     506              :         {
     507           29 :             auto it = m_clients.find(clientId);
     508           29 :             if (it == m_clients.end())
     509              :             {
     510            0 :                 RIALTO_IPC_LOG_ERROR("failed to find condemned client");
     511            0 :                 continue;
     512              :             }
     513              : 
     514           29 :             ClientDetails details = it->second;
     515              : 
     516              :             // remove from the list of clients
     517           29 :             m_clients.erase(it);
     518              : 
     519              :             // drop the lock while closing the connection and removing from epoll
     520           29 :             locker.unlock();
     521              : 
     522              :             // remove the socket from epoll and close it
     523           29 :             if (details.sock >= 0)
     524              :             {
     525           29 :                 if (epoll_ctl(m_pollFd, EPOLL_CTL_DEL, details.sock, nullptr) != 0)
     526            0 :                     RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket from epoll");
     527              : 
     528           29 :                 if (shutdown(details.sock, SHUT_RDWR) != 0)
     529            0 :                     RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to shutdown socket");
     530           29 :                 if (close(details.sock) != 0)
     531            0 :                     RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket");
     532              :             }
     533              : 
     534              :             // let the installed handler know a client has disconnected
     535           29 :             if (details.disconnectedCb)
     536           29 :                 details.disconnectedCb(details.client);
     537              : 
     538              :             // ensure client object is destructed without the client lock held
     539           29 :             details.client.reset();
     540              : 
     541              :             // re-take the lock for the next client
     542           29 :             locker.lock();
     543              :         }
     544              :     }
     545              : 
     546         2922 :     return true;
     547              : }
     548              : 
     549              : // -----------------------------------------------------------------------------
     550              : /*!
     551              :     \internal
     552              : 
     553              :     Adds the \a socketFd to the internal list of client sockets.  This is called
     554              :     when a new connection is accepted on a listening socket, or when a client
     555              :     fd is added via ServerImpl::addClient(...).
     556              : 
     557              :     Returns a nullptr if failed to add the socket.
     558              : 
     559              :  */
     560           29 : std::shared_ptr<ClientImpl> ServerImpl::addClientSocket(int socketFd, const std::string &listeningSocketPath,
     561              :                                                         std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb)
     562              : {
     563              :     // get the client credentials
     564           29 :     struct ucred clientCreds = {0};
     565           29 :     socklen_t clientCredsLen = sizeof(clientCreds);
     566              : 
     567           58 :     if ((getsockopt(socketFd, SOL_SOCKET, SO_PEERCRED, &clientCreds, &clientCredsLen) < 0) ||
     568           29 :         (clientCredsLen != sizeof(clientCreds)))
     569              :     {
     570            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get client's details");
     571            0 :         return nullptr;
     572              :     }
     573              : 
     574              :     // create a new unique client id for the connection
     575           29 :     const uint64_t kClientId = m_clientIdCounter++;
     576              : 
     577              :     // add the new socket to the poll loop
     578           29 :     epoll_event event = {.events = EPOLLIN, .data = {.u64 = kClientId}};
     579           29 :     if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socketFd, &event) != 0)
     580              :     {
     581            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add client socket");
     582            0 :         return nullptr;
     583              :     }
     584              : 
     585              :     // create initial client object for the socket
     586           29 :     auto client = std::make_shared<ClientImpl>(shared_from_this(), kClientId, clientCreds);
     587              : 
     588              :     // and the details for the internal list
     589           29 :     ClientDetails clientDetails;
     590           29 :     clientDetails.sock = socketFd;
     591           29 :     clientDetails.disconnectedCb = std::move(disconnectedCb);
     592           29 :     clientDetails.client = client;
     593              : 
     594              :     // add to the set of clients
     595              :     {
     596           29 :         std::lock_guard<std::mutex> locker(m_clientsLock);
     597           29 :         m_clients.emplace(kClientId, clientDetails);
     598              :     }
     599              : 
     600           29 :     RIALTO_IPC_LOG_INFO("new client connected - giving id %" PRIu64, kClientId);
     601              : 
     602           29 :     return client;
     603              : }
     604              : 
     605              : // -----------------------------------------------------------------------------
     606              : /*!
     607              :     \internal
     608              : 
     609              :     Called when an event occurs on the listening socket. The code first accepts
     610              :     the connection and retrieves the client details, it then calls the installed
     611              :     handler to determine if we should drop this connection or not.
     612              : 
     613              : 
     614              :  */
     615           13 : void ServerImpl::processNewConnection(uint64_t socketId)
     616              : {
     617           13 :     RIALTO_IPC_LOG_DEBUG("processing new connection");
     618              : 
     619           13 :     std::unique_lock<std::mutex> socketLocker(m_socketsLock);
     620              : 
     621              :     // find matching socket object
     622           13 :     auto it = m_sockets.find(socketId);
     623           13 :     if (it == m_sockets.end())
     624              :     {
     625            0 :         RIALTO_IPC_LOG_ERROR("failed to find listening socket with id %" PRIu64, socketId);
     626            0 :         return;
     627              :     }
     628              : 
     629           13 :     const Socket &kSocket = it->second;
     630              : 
     631              :     // accept the connection from the client
     632           13 :     struct sockaddr clientAddr = {0};
     633           13 :     socklen_t clientAddrLen = sizeof(clientAddr);
     634              : 
     635           13 :     int clientSock = accept4(kSocket.sockFd, &clientAddr, &clientAddrLen, SOCK_NONBLOCK | SOCK_CLOEXEC);
     636           13 :     if (clientSock < 0)
     637              :     {
     638            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to accept client connection");
     639            0 :         return;
     640              :     }
     641              : 
     642           13 :     const std::string kSockPath = kSocket.sockPath;
     643           13 :     std::function<void(const std::shared_ptr<IClient> &)> connectedCb = kSocket.connectedCb;
     644           13 :     std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb = kSocket.disconnectedCb;
     645              : 
     646           13 :     socketLocker.unlock();
     647              : 
     648              :     // attempt to add the socket to the client list
     649           13 :     auto client = addClientSocket(clientSock, kSockPath, std::move(disconnectedCb));
     650           13 :     if (!client)
     651              :     {
     652            0 :         close(clientSock);
     653            0 :         return;
     654              :     }
     655              : 
     656              :     // notify the handler that a new connection has been made
     657           13 :     if (connectedCb)
     658              :     {
     659              :         // tell the handler we have a new client
     660           13 :         connectedCb(client);
     661              :     }
     662              : }
     663              : 
     664              : // -----------------------------------------------------------------------------
     665              : /*!
     666              :     \internal
     667              :     \static
     668              : 
     669              :     Reads all the file descriptors from a message. It returns the file descriptors
     670              :     as a vector of FileDescriptor objects, these objects safely store the fd and
     671              :     close them when they're destructed.
     672              : 
     673              :  */
     674            0 : static std::vector<FileDescriptor> readMessageFds(const struct msghdr *msg, size_t limit)
     675              : {
     676            0 :     std::vector<FileDescriptor> fds;
     677              : 
     678            0 :     for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg); cmsg != nullptr;
     679            0 :          cmsg = CMSG_NXTHDR(const_cast<struct msghdr *>(msg), cmsg))
     680              :     {
     681            0 :         if ((cmsg->cmsg_level == SOL_SOCKET) && (cmsg->cmsg_type == SCM_RIGHTS))
     682              :         {
     683            0 :             const unsigned kFdsLength = cmsg->cmsg_len - CMSG_LEN(0);
     684            0 :             if ((kFdsLength < sizeof(int)) || ((kFdsLength % sizeof(int)) != 0))
     685              :             {
     686            0 :                 RIALTO_IPC_LOG_ERROR("invalid fd array size");
     687              :             }
     688              :             else
     689              :             {
     690            0 :                 const size_t n = kFdsLength / sizeof(int);
     691            0 :                 RIALTO_IPC_LOG_DEBUG("received %zu fds", n);
     692              : 
     693            0 :                 fds.reserve(std::min(limit, n));
     694              : 
     695            0 :                 const int *kFds = reinterpret_cast<int *>(CMSG_DATA(cmsg));
     696            0 :                 for (size_t i = 0; i < n; i++)
     697              :                 {
     698            0 :                     RIALTO_IPC_LOG_DEBUG("received fd %d", kFds[i]);
     699              : 
     700            0 :                     if (fds.size() >= limit)
     701              :                     {
     702            0 :                         RIALTO_IPC_LOG_ERROR("received to many file descriptors, "
     703              :                                              "exceeding max per message, closing left overs");
     704              :                     }
     705              :                     else
     706              :                     {
     707            0 :                         FileDescriptor fd(kFds[i]);
     708            0 :                         if (!fd.isValid())
     709              :                         {
     710            0 :                             RIALTO_IPC_LOG_ERROR("received invalid fd (couldn't dup)");
     711              :                         }
     712              :                         else
     713              :                         {
     714            0 :                             fds.emplace_back(std::move(fd));
     715              :                         }
     716              :                     }
     717              : 
     718            0 :                     if (close(kFds[i]) != 0)
     719              :                     {
     720            0 :                         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close received fd");
     721              :                     }
     722              :                 }
     723              :             }
     724              :         }
     725              :     }
     726              : 
     727            0 :     return fds;
     728              : }
     729              : 
     730              : // -----------------------------------------------------------------------------
     731              : /*!
     732              :     \internal
     733              : 
     734              :     Processes an event from a client socket.
     735              : 
     736              :  */
     737           49 : void ServerImpl::processClientSocket(uint64_t clientId, unsigned events)
     738              : {
     739              :     // take the lock while accessing the client list
     740           49 :     std::unique_lock<std::mutex> locker(m_clientsLock);
     741              : 
     742           49 :     auto it = m_clients.find(clientId);
     743           49 :     if (it == m_clients.end())
     744              :     {
     745              :         // should never happen
     746            0 :         RIALTO_IPC_LOG_ERROR("received an event from a socket with no matching client");
     747            0 :         return;
     748              :     }
     749              : 
     750              :     // check if the client is marked for closure, if so then just ignore the data
     751           49 :     if (m_condemnedClients.count(clientId) != 0)
     752              :     {
     753            0 :         return;
     754              :     }
     755              : 
     756              :     // get the socket that corresponds to the client connection
     757           49 :     const int kSockFd = it->second.sock;
     758              : 
     759              :     // get the client object
     760           49 :     std::shared_ptr<ClientImpl> clientObj = it->second.client;
     761              : 
     762              :     // can safely release the lock now we have the clientId and client object
     763           49 :     locker.unlock();
     764              : 
     765              :     // if there was an error disconnect the socket
     766           49 :     if (events & EPOLLERR)
     767              :     {
     768            0 :         RIALTO_IPC_LOG_ERROR("error detected on client socket - disconnecting client");
     769            0 :         disconnectClient(clientId);
     770            0 :         return;
     771              :     }
     772              : 
     773           49 :     if (events & EPOLLIN)
     774              :     {
     775              :         // read all messages from the client socket, we break out if the socket is closed
     776              :         // or EWOULDBLOCK is returned on a read (ie. no more messages to read)
     777              :         while (true)
     778              :         {
     779           70 :             struct msghdr msg = {nullptr};
     780           70 :             struct iovec io = {.iov_base = m_recvDataBuf, .iov_len = sizeof(m_recvDataBuf)};
     781              : 
     782           70 :             bzero(&msg, sizeof(msg));
     783           70 :             msg.msg_iov = &io;
     784           70 :             msg.msg_iovlen = 1;
     785           70 :             msg.msg_control = m_recvCtrlBuf;
     786           70 :             msg.msg_controllen = sizeof(m_recvCtrlBuf);
     787              : 
     788              :             // read one message
     789           70 :             ssize_t rd = TEMP_FAILURE_RETRY(recvmsg(kSockFd, &msg, MSG_CMSG_CLOEXEC));
     790           70 :             if (rd < 0)
     791              :             {
     792           21 :                 if (errno != EWOULDBLOCK)
     793              :                 {
     794            0 :                     RIALTO_IPC_LOG_SYS_ERROR(errno, "error reading client socket");
     795            0 :                     disconnectClient(clientId);
     796              :                 }
     797              : 
     798           49 :                 break;
     799              :             }
     800           49 :             else if (rd == 0)
     801              :             {
     802              :                 // client closed connection, and we've read all data, add to the condemned set
     803              :                 // so is cleaned up once all the events are processed
     804           28 :                 disconnectClient(clientId);
     805              : 
     806           28 :                 break;
     807              :             }
     808           21 :             else if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))
     809              :             {
     810            0 :                 RIALTO_IPC_LOG_WARN("received message from client %" PRIu64 " truncated, discarding", clientId);
     811              : 
     812              :                 // make sure to close all the fds, otherwise we'll leak them
     813            0 :                 readMessageFds(&msg, 16);
     814              :             }
     815              :             else
     816              :             {
     817              :                 // if there is control data then assume fd(s) have been passed
     818           21 :                 if (msg.msg_controllen > 0)
     819              :                 {
     820            0 :                     processClientMessage(clientObj, m_recvDataBuf, rd, readMessageFds(&msg, 16));
     821              :                 }
     822              :                 else
     823              :                 {
     824           21 :                     processClientMessage(clientObj, m_recvDataBuf, rd);
     825              :                 }
     826              :             }
     827              :         }
     828              :     }
     829           49 : }
     830              : 
     831              : // -----------------------------------------------------------------------------
     832              : /*!
     833              :     \internal
     834              : 
     835              :     Places the received file descriptors into the request message.
     836              : 
     837              :     It works by iterating over the fields in the message, finding ones that are
     838              :     marked as 'field_is_fd' and then replacing the received integer value with
     839              :     an actual file descriptor.
     840              : 
     841              :  */
     842           21 : static bool addRequestFileDescriptors(google::protobuf::Message *request, const std::vector<FileDescriptor> &requestFds)
     843              : {
     844           21 :     auto fdIterator = requestFds.begin();
     845              : 
     846           21 :     const google::protobuf::Descriptor *kDescriptor = request->GetDescriptor();
     847           21 :     const google::protobuf::Reflection *kReflection = nullptr;
     848              : 
     849           21 :     const int n = kDescriptor->field_count();
     850           75 :     for (int i = 0; i < n; i++)
     851              :     {
     852           54 :         auto fieldDescriptor = kDescriptor->field(i);
     853           54 :         if (fieldDescriptor->options().HasExtension(field_is_fd) && fieldDescriptor->options().GetExtension(field_is_fd))
     854              :         {
     855            0 :             if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
     856              :             {
     857            0 :                 RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
     858            0 :                 return false;
     859              :             }
     860              : 
     861            0 :             if (!kReflection)
     862              :             {
     863            0 :                 kReflection = request->GetReflection();
     864              :             }
     865              : 
     866            0 :             if (kReflection->HasField(*request, fieldDescriptor))
     867              :             {
     868            0 :                 if (fdIterator == requestFds.end())
     869              :                 {
     870            0 :                     RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but one was supplied");
     871            0 :                     return false;
     872              :                 }
     873              : 
     874            0 :                 kReflection->SetInt32(request, fieldDescriptor, fdIterator->fd());
     875            0 :                 ++fdIterator;
     876              :             }
     877              :         }
     878              :     }
     879              : 
     880           21 :     if (fdIterator != requestFds.end())
     881              :     {
     882            0 :         RIALTO_IPC_LOG_ERROR("received too many file descriptors in the message");
     883            0 :         return false;
     884              :     }
     885              : 
     886           21 :     return true;
     887              : }
     888              : 
     889              : // -----------------------------------------------------------------------------
     890              : /*!
     891              :     \internal
     892              : 
     893              :     Processes a message received on a client socket.
     894              : 
     895              :  */
     896           21 : void ServerImpl::processClientMessage(const std::shared_ptr<ClientImpl> &client, const uint8_t *data, size_t dataLen,
     897              :                                       const std::vector<FileDescriptor> &fds)
     898              : {
     899           21 :     RIALTO_IPC_LOG_DEBUG("processing client message of size %zu bytes (%zu fds) from client %" PRId64, dataLen,
     900              :                          fds.size(), client->id());
     901              : 
     902              :     // parse the message
     903           21 :     transport::MessageToServer message;
     904           21 :     if (!message.ParseFromArray(data, static_cast<int>(dataLen)))
     905              :     {
     906            0 :         RIALTO_IPC_LOG_ERROR("invalid request");
     907            0 :         return;
     908              :     }
     909              : 
     910           21 :     if (message.has_call())
     911              :     {
     912           21 :         processMethodCall(client, message.call(), fds);
     913              :     }
     914              :     else
     915              :     {
     916            0 :         RIALTO_IPC_LOG_WARN("received unknown message type from client");
     917              :     }
     918           21 : }
     919              : 
     920              : // -----------------------------------------------------------------------------
     921              : /*!
     922              :     \internal
     923              : 
     924              :     Processes a method call requst from a client.
     925              : 
     926              :  */
     927           21 : void ServerImpl::processMethodCall(const std::shared_ptr<ClientImpl> &client, const transport::MethodCall &call,
     928              :                                    const std::vector<FileDescriptor> &fds)
     929              : {
     930              :     // try and find the service with the given name
     931           21 :     const std::string &kServiceName = call.service_name();
     932           21 :     auto it = client->m_services.find(kServiceName);
     933           21 :     if (it == client->m_services.end())
     934              :     {
     935            0 :         RIALTO_IPC_LOG_ERROR("unknown service request '%s'", kServiceName.c_str());
     936              : 
     937            0 :         sendErrorReply(client, call.serial_id(), "Unknown service '%s'", kServiceName.c_str());
     938            0 :         return;
     939              :     }
     940              : 
     941           21 :     std::shared_ptr<google::protobuf::Service> service = it->second;
     942              : 
     943              :     // try and find the method
     944           21 :     const std::string &kMethodName = call.method_name();
     945           21 :     const google::protobuf::MethodDescriptor *kMethod = service->GetDescriptor()->FindMethodByName(kMethodName);
     946           21 :     if (!kMethod)
     947              :     {
     948            0 :         RIALTO_IPC_LOG_ERROR("no method with name '%s'", kMethodName.c_str());
     949              : 
     950            0 :         sendErrorReply(client, call.serial_id(), "Unknown method '%s'", kMethodName.c_str());
     951            0 :         return;
     952              :     }
     953              : 
     954              :     // check if the method is expecting a reply
     955           21 :     const bool kNoReply = kMethod->options().HasExtension(no_reply) && kMethod->options().GetExtension(no_reply);
     956              : 
     957              :     // parse the request data
     958           21 :     google::protobuf::Message *requestMessage = service->GetRequestPrototype(kMethod).New();
     959           21 :     if (!requestMessage->ParseFromString(call.request_message()))
     960              :     {
     961            0 :         RIALTO_IPC_LOG_ERROR("failed to parse method from array");
     962              :     }
     963           21 :     else if (!addRequestFileDescriptors(requestMessage, fds))
     964              :     {
     965            0 :         RIALTO_IPC_LOG_ERROR("mismatch of file descriptors to the request");
     966              :     }
     967              :     else
     968              :     {
     969           21 :         RIALTO_IPC_LOG_DEBUG("call{ serial %" PRIu64 " } - %s.%s { %s }", call.serial_id(), kServiceName.c_str(),
     970              :                              kMethodName.c_str(), requestMessage->ShortDebugString().c_str());
     971              : 
     972           21 :         auto *controller = new ServerControllerImpl(client, call.serial_id());
     973              : 
     974           21 :         if (kNoReply)
     975              :         {
     976              :             // we should not send a reply for this call, so call the code to handle the
     977              :             // request, but no need to pass a controller, response or closure object
     978            1 :             static google::protobuf::internal::FunctionClosure0 nullClosure(&google::protobuf::DoNothing, false);
     979            1 :             service->CallMethod(kMethod, controller, requestMessage, nullptr, &nullClosure);
     980              : 
     981            1 :             delete controller;
     982              :         }
     983              :         else
     984              :         {
     985              :             // create a response
     986           20 :             google::protobuf::Message *responseMessage = service->GetResponsePrototype(kMethod).New();
     987              : 
     988              :             // this is finally where we call the service implementation to process the request
     989           20 :             service->CallMethod(kMethod, controller, requestMessage, responseMessage,
     990              :                                 google::protobuf::NewCallback(this, &ServerImpl::handleResponse, controller,
     991              :                                                               responseMessage));
     992              :         }
     993              :     }
     994              : 
     995           21 :     delete requestMessage;
     996              : }
     997              : 
     998              : // -----------------------------------------------------------------------------
     999              : /*!
    1000              :     \internal
    1001              :     \threadsafe
    1002              : 
    1003              :     Sends a message to the given client if still connected.
    1004              : 
    1005              :  */
    1006           20 : void ServerImpl::sendReply(uint64_t clientId, const std::shared_ptr<msghdr> &msg)
    1007              : {
    1008              :     // now take the lock (so the socket is not closed beneath us) and send the reply
    1009           20 :     std::lock_guard<std::mutex> locker(m_clientsLock);
    1010              : 
    1011           20 :     auto it = m_clients.find(clientId);
    1012           20 :     if (it == m_clients.end())
    1013              :     {
    1014            1 :         RIALTO_IPC_LOG_WARN("socket removed before error reply could be sent");
    1015              :     }
    1016           19 :     else if (it->second.sock < 0)
    1017              :     {
    1018            0 :         RIALTO_IPC_LOG_WARN("socket closed before error reply could be sent");
    1019              :     }
    1020           19 :     else if (!msg)
    1021              :     {
    1022            0 :         RIALTO_IPC_LOG_WARN("invalid msg to send on socket, ignoring");
    1023              :     }
    1024           19 :     else if (TEMP_FAILURE_RETRY(sendmsg(it->second.sock, msg.get(), MSG_NOSIGNAL)) !=
    1025           19 :              static_cast<ssize_t>(msg->msg_iov->iov_len))
    1026              :     {
    1027            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to send the complete error reply message");
    1028              :     }
    1029           20 : }
    1030              : 
    1031              : // -----------------------------------------------------------------------------
    1032              : /*!
    1033              :     \internal
    1034              : 
    1035              :     Sends an error back to the client with a reason string in \a format.
    1036              : 
    1037              :  */
    1038            0 : void ServerImpl::sendErrorReply(const std::shared_ptr<ClientImpl> &client, uint64_t serialId, const char *format, ...)
    1039              : {
    1040              :     // construct the error string
    1041              :     va_list ap;
    1042            0 :     va_start(ap, format);
    1043              :     char reason[512];
    1044            0 :     vsnprintf(reason, sizeof(reason), format, ap);
    1045            0 :     va_end(ap);
    1046              : 
    1047              :     // construct the reply message
    1048            0 :     auto msg = populateErrorReply(client, serialId, reason);
    1049              : 
    1050              :     // and send it
    1051            0 :     sendReply(client->id(), msg);
    1052              : }
    1053              : 
    1054              : // -----------------------------------------------------------------------------
    1055              : /*!
    1056              :     \internal
    1057              :     \threadsafe
    1058              : 
    1059              :     Called once the service handler code has processed the request and
    1060              :     completed it.  This may be called from a different thread if the service
    1061              :     implementation decided to off-load the request to a processing thread.
    1062              : 
    1063              :  */
    1064           20 : void ServerImpl::handleResponse(ServerControllerImpl *controller, google::protobuf::Message *response)
    1065              : {
    1066           20 :     if (!controller || !controller->m_kClient)
    1067              :     {
    1068            0 :         RIALTO_IPC_LOG_ERROR("missing controller or attached client");
    1069            0 :         return;
    1070              :     }
    1071              : 
    1072           20 :     const std::shared_ptr<const ClientImpl> kClient = controller->m_kClient;
    1073           20 :     const uint64_t kClientId = kClient->id();
    1074              : 
    1075           20 :     std::shared_ptr<msghdr> message;
    1076           20 :     if (!controller->m_failed)
    1077              :     {
    1078           13 :         message = populateReply(kClient, controller->m_kSerialId, response);
    1079              :     }
    1080              :     else
    1081              :     {
    1082            7 :         message = populateErrorReply(kClient, controller->m_kSerialId, controller->m_failureReason);
    1083              :     }
    1084              : 
    1085              :     // no longer need the controller or the response objects
    1086           20 :     delete response;
    1087           20 :     delete controller;
    1088              : 
    1089              :     // send the reply message to the given client
    1090           20 :     sendReply(kClientId, message);
    1091              : }
    1092              : 
    1093              : // -----------------------------------------------------------------------------
    1094              : /*!
    1095              :     \internal
    1096              : 
    1097              :     Gets all the file descriptors that are stored in the \a response message
    1098              :     and returns them as a vector of ints.
    1099              : 
    1100              :     It works by iterating over the fields in the message, finding ones that are
    1101              :     marked as 'field_is_fd' and inserting the fd value into control part of the
    1102              :     message.  The actual integer sent in the data part of the message is set to
    1103              :     -1.
    1104              : 
    1105              :  */
    1106           17 : static std::vector<int> getResponseFileDescriptors(google::protobuf::Message *response)
    1107              : {
    1108           17 :     std::vector<int> fds;
    1109              : 
    1110              :     // process any file descriptors from the response message
    1111           17 :     const google::protobuf::Descriptor *kDescriptor = response->GetDescriptor();
    1112           17 :     const google::protobuf::Reflection *kReflection = nullptr;
    1113              : 
    1114           17 :     const int n = kDescriptor->field_count();
    1115           34 :     for (int i = 0; i < n; i++)
    1116              :     {
    1117           17 :         auto fieldDescriptor = kDescriptor->field(i);
    1118           17 :         if (fieldDescriptor->options().HasExtension(field_is_fd) && fieldDescriptor->options().GetExtension(field_is_fd))
    1119              :         {
    1120            0 :             if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
    1121              :             {
    1122            0 :                 RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
    1123            0 :                 return {};
    1124              :             }
    1125              : 
    1126            0 :             if (!kReflection)
    1127              :             {
    1128            0 :                 kReflection = response->GetReflection();
    1129              :             }
    1130              : 
    1131            0 :             if (kReflection->HasField(*response, fieldDescriptor))
    1132              :             {
    1133            0 :                 fds.push_back(kReflection->GetInt32(*response, fieldDescriptor));
    1134            0 :                 kReflection->SetInt32(response, fieldDescriptor, -1);
    1135              :             }
    1136              :         }
    1137              :     }
    1138              : 
    1139           17 :     return fds;
    1140              : }
    1141              : 
    1142              : // -----------------------------------------------------------------------------
    1143              : /*!
    1144              :     \internal
    1145              :     \static
    1146              : 
    1147              :     Populates tha socket message buffer with the reply data for the RPC request.
    1148              : 
    1149              :  */
    1150           13 : std::shared_ptr<msghdr> ServerImpl::populateReply(const std::shared_ptr<const ClientImpl> &client, uint64_t serialId,
    1151              :                                                   google::protobuf::Message *response)
    1152              : {
    1153              :     // create the base reply
    1154           13 :     transport::MessageFromServer message;
    1155           13 :     transport::MethodCallReply *reply = message.mutable_reply();
    1156           13 :     if (!reply)
    1157              :     {
    1158            0 :         RIALTO_IPC_LOG_ERROR("failed to create mutable reply object");
    1159            0 :         return nullptr;
    1160              :     }
    1161              : 
    1162              :     // convert the message response to a data string
    1163           13 :     std::string respString = response->SerializeAsString();
    1164              : 
    1165              :     // wrap in a transport response and send that
    1166           13 :     reply->set_reply_id(serialId);
    1167           13 :     reply->set_reply_message(std::move(respString));
    1168              : 
    1169              :     // next need to check if the response message has any file descriptors in
    1170              :     // it that need to be attached
    1171           13 :     const std::vector<int> kFds = getResponseFileDescriptors(response);
    1172           13 :     const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
    1173              : 
    1174              :     // calculate the size of the reply
    1175           13 :     const size_t kRequiredDataLen = message.ByteSizeLong();
    1176           13 :     if (kRequiredDataLen > m_kMaxMessageLen)
    1177              :     {
    1178            0 :         RIALTO_IPC_LOG_ERROR("reply exceeds maximum message limit (%zu, max %zu)", kRequiredDataLen, m_kMaxMessageLen);
    1179              : 
    1180              :         // error message is too big, replace with a generic error
    1181            0 :         return populateErrorReply(client, serialId, "Internal error - reply message to large");
    1182              :     }
    1183              : 
    1184              :     // build the socket message to send
    1185              :     auto msgBuf =
    1186           13 :         m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + kRequiredDataLen);
    1187              : 
    1188           13 :     auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
    1189           13 :     bzero(header, sizeof(msghdr));
    1190              : 
    1191           13 :     auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
    1192           13 :     header->msg_control = ctrl;
    1193           13 :     header->msg_controllen = kRequiredCtrlLen;
    1194              : 
    1195           13 :     auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
    1196           13 :     header->msg_iov = iov;
    1197           13 :     header->msg_iovlen = 1;
    1198              : 
    1199           13 :     auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
    1200           13 :     iov->iov_base = data;
    1201           13 :     iov->iov_len = kRequiredDataLen;
    1202              : 
    1203              :     // copy in the data
    1204           13 :     message.SerializeWithCachedSizesToArray(data);
    1205              : 
    1206              :     // add the fds
    1207           13 :     if (!kFds.empty())
    1208              :     {
    1209            0 :         struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
    1210            0 :         if (!cmsg)
    1211              :         {
    1212            0 :             RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
    1213            0 :             return nullptr;
    1214              :         }
    1215              : 
    1216            0 :         cmsg->cmsg_level = SOL_SOCKET;
    1217            0 :         cmsg->cmsg_type = SCM_RIGHTS;
    1218            0 :         cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
    1219            0 :         memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
    1220            0 :         header->msg_controllen = cmsg->cmsg_len;
    1221              :     }
    1222              : 
    1223           13 :     RIALTO_IPC_LOG_DEBUG("reply{ serial %" PRIu64 " } - { %s }", serialId, response->ShortDebugString().c_str());
    1224              : 
    1225              :     // std::reinterpret_pointer_cast is only implemented in C++17 and newer, so for
    1226              :     // now do it manually
    1227           13 :     return std::shared_ptr<msghdr>(msgBuf, reinterpret_cast<msghdr *>(msgBuf.get()));
    1228              : }
    1229              : 
    1230              : // -----------------------------------------------------------------------------
    1231              : /*!
    1232              :     \internal
    1233              : 
    1234              :     Populates tha socket message buffer with a reply error message.
    1235              : 
    1236              :  */
    1237            7 : std::shared_ptr<msghdr> ServerImpl::populateErrorReply(const std::shared_ptr<const ClientImpl> &client,
    1238              :                                                        uint64_t serialId, const std::string &reason)
    1239              : {
    1240              :     // create the base reply
    1241            7 :     transport::MessageFromServer message;
    1242            7 :     transport::MethodCallError *error = message.mutable_error();
    1243            7 :     error->set_reply_id(serialId);
    1244              :     error->set_error_reason(reason);
    1245              : 
    1246              :     // check the message will fit
    1247            7 :     size_t replySize = message.ByteSizeLong();
    1248            7 :     if (replySize > m_kMaxMessageLen)
    1249              :     {
    1250            0 :         RIALTO_IPC_LOG_ERROR("error reply exceeds max message size");
    1251              : 
    1252              :         // error message is to big, replace with a generic error
    1253              :         error->set_error_reason("Error message truncated");
    1254            0 :         replySize = message.ByteSizeLong();
    1255              :     }
    1256              : 
    1257            7 :     RIALTO_IPC_LOG_DEBUG("error{ serial %" PRIu64 " } - \"%s\"", serialId, reason.c_str());
    1258              : 
    1259              :     // construct the message to send on the socket
    1260            7 :     auto msgBuf = m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + replySize);
    1261              : 
    1262            7 :     auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
    1263            7 :     bzero(header, sizeof(msghdr));
    1264              : 
    1265            7 :     auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr));
    1266            7 :     header->msg_iov = iov;
    1267            7 :     header->msg_iovlen = 1;
    1268              : 
    1269            7 :     auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + sizeof(iovec));
    1270            7 :     iov->iov_base = data;
    1271            7 :     iov->iov_len = replySize;
    1272              : 
    1273              :     // serialise the reply to the message buffer
    1274            7 :     message.SerializeWithCachedSizesToArray(data);
    1275              : 
    1276              :     // std::reinterpret_pointer_cast is only implemented in C++17 and newer, so for
    1277              :     // now do it manually
    1278           14 :     return std::shared_ptr<msghdr>(msgBuf, reinterpret_cast<msghdr *>(msgBuf.get()));
    1279            7 : }
    1280              : 
    1281              : // -----------------------------------------------------------------------------
    1282              : /*!
    1283              :     \threadsafe
    1284              : 
    1285              :     Returns \c true if the client with \a clientId is currently connected to
    1286              :     the server.
    1287              : 
    1288              :  */
    1289           61 : bool ServerImpl::isClientConnected(uint64_t clientId) const
    1290              : {
    1291           61 :     std::unique_lock<std::mutex> locker(m_clientsLock);
    1292          122 :     return (m_clients.count(clientId) > 0);
    1293           61 : }
    1294              : 
    1295              : // -----------------------------------------------------------------------------
    1296              : /*!
    1297              :     \threadsafe
    1298              : 
    1299              :     May be called internally when there is an error on the socket, or externally
    1300              :     (possibly from a different thread) if the handler or service code decides to
    1301              :     close the client connection.
    1302              : 
    1303              :  */
    1304           29 : void ServerImpl::disconnectClient(uint64_t clientId)
    1305              : {
    1306           29 :     std::unique_lock<std::mutex> locker(m_clientsLock);
    1307           29 :     m_condemnedClients.insert(clientId);
    1308           29 :     locker.unlock();
    1309              : 
    1310           29 :     wakeEventLoop();
    1311              : }
    1312              : 
    1313              : // -----------------------------------------------------------------------------
    1314              : /*!
    1315              :     \threadsafe
    1316              : 
    1317              :     Called via the IAVBusClient interface when an async event should be sent.
    1318              : 
    1319              :     This may be called from any thread, or from within the rpc message handler.
    1320              : 
    1321              :     The \a clientId is the client to send the event to.
    1322              : 
    1323              :  */
    1324            4 : bool ServerImpl::sendEvent(uint64_t clientId, const std::shared_ptr<google::protobuf::Message> &eventMessage)
    1325              : {
    1326              :     // gets the file descriptors from the event message
    1327            4 :     const std::vector<int> kFds = getResponseFileDescriptors(eventMessage.get());
    1328            4 :     const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
    1329              : 
    1330              :     // create the base reply
    1331            4 :     transport::MessageFromServer message;
    1332            4 :     transport::EventFromServer *event = message.mutable_event();
    1333            4 :     if (!event)
    1334              :     {
    1335            0 :         RIALTO_IPC_LOG_ERROR("failed to create mutable event object");
    1336            0 :         return false;
    1337              :     }
    1338              : 
    1339            8 :     event->set_event_name(eventMessage->GetTypeName());
    1340              : 
    1341              :     // convert the event to a data string
    1342            4 :     std::string respString = eventMessage->SerializeAsString();
    1343              : 
    1344              :     // wrap in a transport response and send that
    1345            4 :     event->set_message(std::move(respString));
    1346              : 
    1347              :     // check the reply will fit
    1348            4 :     size_t requiredDataLen = message.ByteSizeLong();
    1349            4 :     if (requiredDataLen > m_kMaxMessageLen)
    1350              :     {
    1351            0 :         RIALTO_IPC_LOG_ERROR("event message to big to fit in buffer (size %zu, max size %zu)", requiredDataLen,
    1352              :                              m_kMaxMessageLen);
    1353            0 :         return false;
    1354              :     }
    1355              : 
    1356              :     // build the socket message to send
    1357              :     auto msgBuf =
    1358            4 :         m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + requiredDataLen);
    1359              : 
    1360            4 :     auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
    1361            4 :     bzero(header, sizeof(msghdr));
    1362              : 
    1363            4 :     auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
    1364            4 :     header->msg_control = ctrl;
    1365            4 :     header->msg_controllen = kRequiredCtrlLen;
    1366              : 
    1367            4 :     auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
    1368            4 :     header->msg_iov = iov;
    1369            4 :     header->msg_iovlen = 1;
    1370              : 
    1371            4 :     auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
    1372            4 :     iov->iov_base = data;
    1373            4 :     iov->iov_len = requiredDataLen;
    1374              : 
    1375              :     // copy in the data
    1376            4 :     message.SerializeWithCachedSizesToArray(data);
    1377              : 
    1378              :     // add the fds
    1379            4 :     if (!kFds.empty())
    1380              :     {
    1381            0 :         struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
    1382            0 :         if (!cmsg)
    1383              :         {
    1384            0 :             RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
    1385            0 :             return false;
    1386              :         }
    1387              : 
    1388            0 :         cmsg->cmsg_level = SOL_SOCKET;
    1389            0 :         cmsg->cmsg_type = SCM_RIGHTS;
    1390            0 :         cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
    1391            0 :         memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
    1392            0 :         header->msg_controllen = cmsg->cmsg_len;
    1393              :     }
    1394              : 
    1395              :     // finally, take the lock (so the socket is not closed beneath us) and send the reply
    1396            4 :     std::unique_lock<std::mutex> locker(m_clientsLock);
    1397              : 
    1398            4 :     auto it = m_clients.find(clientId);
    1399            4 :     if (it == m_clients.end() || it->second.sock < 0)
    1400              :     {
    1401            0 :         RIALTO_IPC_LOG_WARN("socket closed before event could be sent");
    1402            0 :         return false;
    1403              :     }
    1404            4 :     else if (TEMP_FAILURE_RETRY(sendmsg(it->second.sock, header, MSG_NOSIGNAL)) != static_cast<ssize_t>(requiredDataLen))
    1405              :     {
    1406            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to send the complete event message");
    1407            0 :         return false;
    1408              :     }
    1409              : 
    1410            4 :     locker.unlock();
    1411              : 
    1412            4 :     RIALTO_IPC_LOG_DEBUG("event{ %s } - { %s }", eventMessage->GetTypeName().c_str(),
    1413              :                          eventMessage->ShortDebugString().c_str());
    1414              : 
    1415            4 :     return true;
    1416              : }
    1417              : } // namespace firebolt::rialto::ipc
        

Generated by: LCOV version 2.0-1