LCOV - code coverage report
Current view: top level - ipc/server/source - IpcServerImpl.cpp (source / functions) Coverage Total Hit
Test: coverage.info Lines: 62.4 % 540 337
Test Date: 2025-02-18 13:13:53 Functions: 85.7 % 28 24

            Line data    Source code
       1              : /*
       2              :  * If not stated otherwise in this file or this component's LICENSE file the
       3              :  * following copyright and licenses apply:
       4              :  *
       5              :  * Copyright 2022 Sky UK
       6              :  *
       7              :  * Licensed under the Apache License, Version 2.0 (the "License");
       8              :  * you may not use this file except in compliance with the License.
       9              :  * You may obtain a copy of the License at
      10              :  *
      11              :  * http://www.apache.org/licenses/LICENSE-2.0
      12              :  *
      13              :  * Unless required by applicable law or agreed to in writing, software
      14              :  * distributed under the License is distributed on an "AS IS" BASIS,
      15              :  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      16              :  * See the License for the specific language governing permissions and
      17              :  * limitations under the License.
      18              :  */
      19              : 
      20              : #include "IpcServerImpl.h"
      21              : #include "IIpcServerFactory.h"
      22              : #include "IpcClientImpl.h"
      23              : #include "IpcLogging.h"
      24              : #include "IpcServerControllerImpl.h"
      25              : 
      26              : #include "rialtoipc.pb.h"
      27              : 
      28              : #include <google/protobuf/service.h>
      29              : 
      30              : #include <algorithm>
      31              : #include <cinttypes>
      32              : #include <cstdarg>
      33              : #include <string>
      34              : #include <utility>
      35              : #include <vector>
      36              : 
      37              : #include <fcntl.h>
      38              : #include <poll.h>
      39              : #include <sys/epoll.h>
      40              : #include <sys/eventfd.h>
      41              : #include <sys/file.h>
      42              : #include <sys/socket.h>
      43              : #include <sys/stat.h>
      44              : #include <sys/un.h>
      45              : #include <unistd.h>
      46              : 
      47              : #define WAKE_EVENT_ID uint64_t(0)
      48              : #define FIRST_LISTENING_SOCKET_ID uint64_t(1)
      49              : #define FIRST_CLIENT_ID uint64_t(10000)
      50              : 
      51              : namespace firebolt::rialto::ipc
      52              : {
      53              : const size_t ServerImpl::m_kMaxMessageLen = (128 * 1024);
      54              : 
      55           26 : std::shared_ptr<IServerFactory> IServerFactory::createFactory()
      56              : {
      57           26 :     std::shared_ptr<IServerFactory> factory;
      58              :     try
      59              :     {
      60           26 :         factory = std::make_shared<ServerFactory>();
      61              :     }
      62            0 :     catch (const std::exception &e)
      63              :     {
      64            0 :         RIALTO_IPC_LOG_ERROR("Failed to create the server factory, reason: %s", e.what());
      65              :     }
      66              : 
      67           26 :     return factory;
      68              : }
      69              : 
      70           26 : std::shared_ptr<IServer> ServerFactory::create()
      71              : {
      72           26 :     return std::make_shared<ServerImpl>();
      73              : }
      74              : 
      75           26 : ServerImpl::ServerImpl()
      76           26 :     : m_pollFd(-1), m_wakeEventFd(-1), m_socketIdCounter(FIRST_LISTENING_SOCKET_ID),
      77           26 :       m_clientIdCounter(FIRST_CLIENT_ID), m_recvDataBuf{0}, m_recvCtrlBuf{0}
      78              : {
      79              :     // create the eventfd use to wake the poll loop
      80           26 :     m_wakeEventFd = eventfd(0, EFD_CLOEXEC);
      81           26 :     if (m_wakeEventFd < 0)
      82              :     {
      83            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "eventfd failed");
      84            0 :         return;
      85              :     }
      86              : 
      87              :     // create epoll loop
      88           26 :     m_pollFd = epoll_create1(EPOLL_CLOEXEC);
      89           26 :     if (m_pollFd < 0)
      90              :     {
      91            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_create1 failed");
      92            0 :         return;
      93              :     }
      94              : 
      95              :     // add the wake event to epoll
      96           26 :     epoll_event event = {.events = EPOLLIN, .data = {.u64 = WAKE_EVENT_ID}};
      97           26 :     if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, m_wakeEventFd, &event) != 0)
      98              :     {
      99            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
     100              :     }
     101              : }
     102              : 
     103           26 : ServerImpl::~ServerImpl()
     104              : {
     105           26 :     if ((m_pollFd >= 0) && (close(m_pollFd) != 0))
     106            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close epoll");
     107              : 
     108           26 :     if ((m_wakeEventFd >= 0) && (close(m_wakeEventFd) != 0))
     109            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close eventfd");
     110              : 
     111           35 :     for (const auto &entry : m_sockets)
     112              :     {
     113            9 :         const Socket &kSocket = entry.second;
     114              : 
     115            9 :         if (unlink(kSocket.sockPath.c_str()) != 0)
     116            0 :             RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket @ '%s'", kSocket.sockPath.c_str());
     117            9 :         if (close(kSocket.sockFd) != 0)
     118            0 :             RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close listening socket");
     119              : 
     120            9 :         if (unlink(kSocket.lockPath.c_str()) != 0)
     121            0 :             RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket lock file @ '%s'", kSocket.lockPath.c_str());
     122            9 :         if (close(kSocket.lockFd) != 0)
     123            0 :             RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket lock file");
     124              :     }
     125           26 : }
     126              : 
     127              : // -----------------------------------------------------------------------------
     128              : /*!
     129              :     \static
     130              :     \internal
     131              : 
     132              :     Creates (if required) and takes the file lock associated with the socket
     133              :     path in the \a socket object.
     134              : 
     135              :  */
     136            9 : bool ServerImpl::getSocketLock(Socket *socket)
     137              : {
     138            9 :     std::string lockPath = socket->sockPath + ".lock";
     139            9 :     int fd = open(lockPath.c_str(), O_CREAT | O_CLOEXEC, (S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP));
     140            9 :     if (fd < 0)
     141              :     {
     142            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to create / open lockfile @ '%s' (check permissions)", lockPath.c_str());
     143            0 :         return false;
     144              :     }
     145              : 
     146            9 :     if (flock(fd, LOCK_EX | LOCK_NB) < 0)
     147              :     {
     148            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to lock lockfile @ '%s', maybe another server is running",
     149              :                                  lockPath.c_str());
     150            0 :         close(fd);
     151            0 :         return false;
     152              :     }
     153              : 
     154            9 :     struct stat sbuf = {0};
     155            9 :     if (stat(socket->sockPath.c_str(), &sbuf) < 0)
     156              :     {
     157            9 :         if (errno != ENOENT)
     158              :         {
     159            0 :             RIALTO_IPC_LOG_SYS_ERROR(errno, "did not manage to stat existing socket @ '%s'", socket->sockPath.c_str());
     160            0 :             close(fd);
     161            0 :             return false;
     162              :         }
     163              :     }
     164            0 :     else if ((sbuf.st_mode & S_IWUSR) || (sbuf.st_mode & S_IWGRP))
     165              :     {
     166            0 :         unlink(socket->sockPath.c_str());
     167              :     }
     168              : 
     169            9 :     socket->lockFd = fd;
     170            9 :     socket->lockPath = std::move(lockPath);
     171              : 
     172            9 :     return true;
     173              : }
     174              : 
     175              : // -----------------------------------------------------------------------------
     176              : /*!
     177              :     \static
     178              :     \internal
     179              : 
     180              :     Closes the file descriptors and the deletes the files stored in the \a socket
     181              :     object.
     182              : 
     183              :  */
     184            0 : void ServerImpl::closeListeningSocket(Socket *socket)
     185              : {
     186            0 :     if (!socket->sockPath.empty() && (unlink(socket->sockPath.c_str()) != 0) && (errno != ENOENT))
     187            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket @ '%s'", socket->sockPath.c_str());
     188            0 :     if ((socket->sockFd >= 0) && (close(socket->sockFd) != 0))
     189            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close listening socket");
     190              : 
     191            0 :     if (!socket->lockPath.empty() && (unlink(socket->lockPath.c_str()) != 0) && (errno != ENOENT))
     192            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket lock file @ '%s'", socket->lockPath.c_str());
     193            0 :     if ((socket->lockFd >= 0) && (close(socket->lockFd) != 0))
     194            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket lock file");
     195              : 
     196            0 :     socket->sockFd = -1;
     197            0 :     socket->sockPath.clear();
     198              : 
     199            0 :     socket->lockFd = -1;
     200            0 :     socket->lockPath.clear();
     201              : }
     202              : 
     203            9 : bool ServerImpl::addSocket(const std::string &socketPath,
     204              :                            std::function<void(const std::shared_ptr<IClient> &)> clientConnectedCb,
     205              :                            std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb)
     206              : {
     207              :     // store the path
     208            9 :     Socket socket;
     209            9 :     socket.sockPath = socketPath;
     210              : 
     211              :     // create the socket
     212            9 :     socket.sockFd = ::socket(AF_UNIX, SOCK_SEQPACKET | SOCK_CLOEXEC | SOCK_NONBLOCK, 0);
     213            9 :     if (socket.sockFd == -1)
     214              :     {
     215            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "socket error");
     216            0 :         return false;
     217              :     }
     218              : 
     219              :     // get the socket lock
     220            9 :     if (!getSocketLock(&socket))
     221              :     {
     222            0 :         closeListeningSocket(&socket);
     223            0 :         return false;
     224              :     }
     225              : 
     226              :     // bind to the given path
     227            9 :     struct sockaddr_un addr = {0};
     228            9 :     memset(&addr, 0x00, sizeof(addr));
     229            9 :     addr.sun_family = AF_UNIX;
     230            9 :     strncpy(addr.sun_path, socketPath.c_str(), sizeof(addr.sun_path) - 1);
     231              : 
     232            9 :     if (bind(socket.sockFd, (struct sockaddr *)&addr, sizeof(addr)) == -1)
     233              :     {
     234            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "bind error");
     235              : 
     236            0 :         closeListeningSocket(&socket);
     237            0 :         return false;
     238              :     }
     239              : 
     240              :     // put in listening mode
     241            9 :     if (listen(socket.sockFd, 1) == -1)
     242              :     {
     243            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "listen error");
     244              : 
     245            0 :         closeListeningSocket(&socket);
     246            0 :         return false;
     247              :     }
     248              : 
     249              :     // create an id for the listening socket
     250            9 :     const uint64_t kSocketId = m_socketIdCounter++;
     251            9 :     if (kSocketId >= FIRST_CLIENT_ID)
     252              :     {
     253              :         // should never happen, we'd run out of file descriptors before
     254              :         // we hit the 10k limit on listening sockets
     255            0 :         RIALTO_IPC_LOG_ERROR("too many listening sockets");
     256              : 
     257            0 :         closeListeningSocket(&socket);
     258            0 :         return false;
     259              :     }
     260              : 
     261              :     // add the socket to epoll
     262            9 :     epoll_event event = {.events = EPOLLIN, .data = {.u64 = kSocketId}};
     263            9 :     if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socket.sockFd, &event) != 0)
     264              :     {
     265            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add listening socket");
     266              : 
     267            0 :         closeListeningSocket(&socket);
     268            0 :         return false;
     269              :     }
     270              : 
     271              :     // store the client connected / disconnected callbacks
     272            9 :     socket.connectedCb = clientConnectedCb;
     273            9 :     socket.disconnectedCb = clientDisconnectedCb;
     274              : 
     275              :     // add to the internal map
     276            9 :     std::lock_guard<std::mutex> locker(m_socketsLock);
     277            9 :     m_sockets.emplace(kSocketId, std::move(socket));
     278              : 
     279            9 :     RIALTO_IPC_LOG_INFO("added listening socket '%s' to server", socketPath.c_str());
     280              : 
     281            9 :     return true;
     282              : }
     283              : 
     284           14 : std::shared_ptr<IClient> ServerImpl::addClient(int socketFd,
     285              :                                                std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb)
     286              : {
     287              :     // sanity check the supplied socket is of the right type
     288              :     struct sockaddr addr;
     289           14 :     socklen_t len = sizeof(sockaddr);
     290           14 :     if ((getsockname(socketFd, &addr, &len) < 0) || (len < sizeof(sa_family_t)))
     291              :     {
     292            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get name of supplied socket");
     293            0 :         return nullptr;
     294              :     }
     295           14 :     if (addr.sa_family != AF_UNIX)
     296              :     {
     297            0 :         RIALTO_IPC_LOG_ERROR("supplied client socket is not a unix domain socket");
     298            0 :         return nullptr;
     299              :     }
     300              : 
     301           14 :     int type = 0;
     302           14 :     len = sizeof(type);
     303           14 :     if ((getsockopt(socketFd, SOL_SOCKET, SO_TYPE, &type, &len) < 0) || (len != sizeof(type)))
     304              :     {
     305            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
     306            0 :         return nullptr;
     307              :     }
     308           14 :     if (type != SOCK_SEQPACKET)
     309              :     {
     310            0 :         RIALTO_IPC_LOG_ERROR("supplied client socket is not of type SOCK_SEQPACKET");
     311            0 :         return nullptr;
     312              :     }
     313              : 
     314              :     // dup the socket and set the O_CLOEXEC bit
     315           14 :     int duppedFd = fcntl(socketFd, F_DUPFD_CLOEXEC, 3);
     316           14 :     if (duppedFd < 0)
     317              :     {
     318            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
     319            0 :         return nullptr;
     320              :     }
     321              : 
     322              :     // ensure the SOCK_NONBLOCK flag is set
     323           14 :     int flags = fcntl(duppedFd, F_GETFL);
     324           14 :     if ((flags < 0) || (fcntl(duppedFd, F_SETFL, flags | O_NONBLOCK) < 0))
     325              :     {
     326            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to set socket to non-blocking mode");
     327            0 :         close(duppedFd);
     328            0 :         return nullptr;
     329              :     }
     330              : 
     331              :     // finally, add the socket to the list of clients
     332           28 :     auto client = addClientSocket(duppedFd, "", std::move(clientDisconnectedCb));
     333           14 :     if (!client)
     334              :     {
     335            0 :         close(duppedFd);
     336              :     }
     337              : 
     338           14 :     return client;
     339              : }
     340              : 
     341            0 : int ServerImpl::fd() const
     342              : {
     343            0 :     return m_pollFd;
     344              : }
     345              : 
     346              : // -----------------------------------------------------------------------------
     347              : /*!
     348              :     \internal
     349              : 
     350              :     Writes to the eventfd to wake the event loop.  Typically called when requesting
     351              :     it to shutdown or external code has requested that a client be disconnected.
     352              : 
     353              :  */
     354           23 : void ServerImpl::wakeEventLoop() const
     355              : {
     356           23 :     if (m_wakeEventFd < 0)
     357              :     {
     358            0 :         RIALTO_IPC_LOG_ERROR("invalid wake event fd");
     359              :     }
     360              :     else
     361              :     {
     362           23 :         uint64_t value = 1;
     363           23 :         if (TEMP_FAILURE_RETRY(::write(m_wakeEventFd, &value, sizeof(value))) != sizeof(value))
     364              :         {
     365            0 :             RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to write to the event fd");
     366              :         }
     367              :     }
     368           23 : }
     369              : 
     370         2879 : bool ServerImpl::wait(int timeoutMSecs)
     371              : {
     372         2879 :     if (m_pollFd < 0)
     373              :     {
     374            0 :         return false;
     375              :     }
     376              : 
     377              :     // wait for any event (with timeout)
     378              :     struct pollfd fds[2];
     379         2879 :     fds[0].fd = m_pollFd;
     380         2879 :     fds[0].events = POLLIN;
     381              : 
     382         2879 :     int rc = TEMP_FAILURE_RETRY(poll(fds, 1, timeoutMSecs));
     383         2879 :     if (rc < 0)
     384              :     {
     385            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "poll failed?");
     386            0 :         return false;
     387              :     }
     388              : 
     389         2879 :     return true;
     390              : }
     391              : 
     392         2902 : bool ServerImpl::process()
     393              : {
     394         2902 :     if (m_pollFd < 0)
     395              :     {
     396            0 :         RIALTO_IPC_LOG_ERROR("missing epoll");
     397            0 :         return false;
     398              :     }
     399              : 
     400              :     // read up to 32 events
     401         2902 :     const int kMaxEvents = 32;
     402              :     struct epoll_event events[kMaxEvents];
     403              : 
     404         2902 :     int rc = TEMP_FAILURE_RETRY(epoll_wait(m_pollFd, events, kMaxEvents, 0));
     405         2902 :     if (rc < 0)
     406              :     {
     407            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_wait failed");
     408            0 :         return false;
     409              :     }
     410              : 
     411              :     // process the events (maybe 0 if timed out)
     412         2951 :     for (int i = 0; i < rc; i++)
     413              :     {
     414           49 :         const struct epoll_event &kEvent = events[i];
     415              : 
     416              :         // check if a wake event, in which case just clear the eventfd
     417           49 :         if (kEvent.data.u64 == WAKE_EVENT_ID)
     418              :         {
     419              :             uint64_t ignore;
     420            3 :             if (TEMP_FAILURE_RETRY(::read(m_wakeEventFd, &ignore, sizeof(ignore))) != sizeof(ignore))
     421              :             {
     422            0 :                 RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to read wake eventfd");
     423              :             }
     424              :         }
     425              : 
     426              :         // check for events on the listening socket
     427           46 :         else if (kEvent.data.u64 < FIRST_CLIENT_ID)
     428              :         {
     429            9 :             if (kEvent.events & EPOLLIN)
     430            9 :                 processNewConnection(kEvent.data.u64);
     431            9 :             if (kEvent.events & EPOLLERR)
     432            0 :                 RIALTO_IPC_LOG_ERROR("error occurred on listening socket");
     433              :         }
     434              : 
     435              :         // otherwise, the event must have come from a socket
     436              :         else
     437              :         {
     438           37 :             processClientSocket(kEvent.data.u64, kEvent.events);
     439              :         }
     440              :     }
     441              : 
     442              :     // if we have client sockets that are condemned then we need to shut down
     443              :     // and close them as well as remove from epoll
     444         2902 :     std::unique_lock<std::mutex> locker(m_clientsLock);
     445         2902 :     if (!m_condemnedClients.empty())
     446              :     {
     447              :         // take a copy of the set so we can process without the lock held
     448           23 :         std::set<uint64_t> theCondemned;
     449           23 :         m_condemnedClients.swap(theCondemned);
     450              : 
     451           46 :         for (uint64_t clientId : theCondemned)
     452              :         {
     453           23 :             auto it = m_clients.find(clientId);
     454           23 :             if (it == m_clients.end())
     455              :             {
     456            0 :                 RIALTO_IPC_LOG_ERROR("failed to find condemned client");
     457            0 :                 continue;
     458              :             }
     459              : 
     460           23 :             ClientDetails details = it->second;
     461              : 
     462              :             // remove from the list of clients
     463           23 :             m_clients.erase(it);
     464              : 
     465              :             // drop the lock while closing the connection and removing from epoll
     466           23 :             locker.unlock();
     467              : 
     468              :             // remove the socket from epoll and close it
     469           23 :             if (details.sock >= 0)
     470              :             {
     471           23 :                 if (epoll_ctl(m_pollFd, EPOLL_CTL_DEL, details.sock, nullptr) != 0)
     472            0 :                     RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket from epoll");
     473              : 
     474           23 :                 if (shutdown(details.sock, SHUT_RDWR) != 0)
     475            0 :                     RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to shutdown socket");
     476           23 :                 if (close(details.sock) != 0)
     477            0 :                     RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket");
     478              :             }
     479              : 
     480              :             // let the installed handler know a client has disconnected
     481           23 :             if (details.disconnectedCb)
     482           23 :                 details.disconnectedCb(details.client);
     483              : 
     484              :             // ensure client object is destructed without the client lock held
     485           23 :             details.client.reset();
     486              : 
     487              :             // re-take the lock for the next client
     488           23 :             locker.lock();
     489              :         }
     490              :     }
     491              : 
     492         2902 :     return true;
     493              : }
     494              : 
     495              : // -----------------------------------------------------------------------------
     496              : /*!
     497              :     \internal
     498              : 
     499              :     Adds the \a socketFd to the internal list of client sockets.  This is called
     500              :     when a new connection is accepted on a listening socket, or when a client
     501              :     fd is added via ServerImpl::addClient(...).
     502              : 
     503              :     Returns a nullptr if failed to add the socket.
     504              : 
     505              :  */
     506           23 : std::shared_ptr<ClientImpl> ServerImpl::addClientSocket(int socketFd, const std::string &listeningSocketPath,
     507              :                                                         std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb)
     508              : {
     509              :     // get the client credentials
     510           23 :     struct ucred clientCreds = {0};
     511           23 :     socklen_t clientCredsLen = sizeof(clientCreds);
     512              : 
     513           46 :     if ((getsockopt(socketFd, SOL_SOCKET, SO_PEERCRED, &clientCreds, &clientCredsLen) < 0) ||
     514           23 :         (clientCredsLen != sizeof(clientCreds)))
     515              :     {
     516            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get client's details");
     517            0 :         return nullptr;
     518              :     }
     519              : 
     520              :     // create a new unique client id for the connection
     521           23 :     const uint64_t kClientId = m_clientIdCounter++;
     522              : 
     523              :     // add the new socket to the poll loop
     524           23 :     epoll_event event = {.events = EPOLLIN, .data = {.u64 = kClientId}};
     525           23 :     if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socketFd, &event) != 0)
     526              :     {
     527            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add client socket");
     528            0 :         return nullptr;
     529              :     }
     530              : 
     531              :     // create initial client object for the socket
     532           23 :     auto client = std::make_shared<ClientImpl>(shared_from_this(), kClientId, clientCreds);
     533              : 
     534              :     // and the details for the internal list
     535           23 :     ClientDetails clientDetails;
     536           23 :     clientDetails.sock = socketFd;
     537           23 :     clientDetails.disconnectedCb = std::move(disconnectedCb);
     538           23 :     clientDetails.client = client;
     539              : 
     540              :     // add to the set of clients
     541              :     {
     542           23 :         std::lock_guard<std::mutex> locker(m_clientsLock);
     543           23 :         m_clients.emplace(kClientId, clientDetails);
     544              :     }
     545              : 
     546           23 :     RIALTO_IPC_LOG_INFO("new client connected - giving id %" PRIu64, kClientId);
     547              : 
     548           23 :     return client;
     549              : }
     550              : 
     551              : // -----------------------------------------------------------------------------
     552              : /*!
     553              :     \internal
     554              : 
     555              :     Called when an event occurs on the listening socket. The code first accepts
     556              :     the connection and retrieves the client details, it then calls the installed
     557              :     handler to determine if we should drop this connection or not.
     558              : 
     559              : 
     560              :  */
     561            9 : void ServerImpl::processNewConnection(uint64_t socketId)
     562              : {
     563            9 :     RIALTO_IPC_LOG_DEBUG("processing new connection");
     564              : 
     565            9 :     std::unique_lock<std::mutex> socketLocker(m_socketsLock);
     566              : 
     567              :     // find matching socket object
     568            9 :     auto it = m_sockets.find(socketId);
     569            9 :     if (it == m_sockets.end())
     570              :     {
     571            0 :         RIALTO_IPC_LOG_ERROR("failed to find listening socket with id %" PRIu64, socketId);
     572            0 :         return;
     573              :     }
     574              : 
     575            9 :     const Socket &kSocket = it->second;
     576              : 
     577              :     // accept the connection from the client
     578            9 :     struct sockaddr clientAddr = {0};
     579            9 :     socklen_t clientAddrLen = sizeof(clientAddr);
     580              : 
     581            9 :     int clientSock = accept4(kSocket.sockFd, &clientAddr, &clientAddrLen, SOCK_NONBLOCK | SOCK_CLOEXEC);
     582            9 :     if (clientSock < 0)
     583              :     {
     584            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to accept client connection");
     585            0 :         return;
     586              :     }
     587              : 
     588            9 :     const std::string kSockPath = kSocket.sockPath;
     589            9 :     std::function<void(const std::shared_ptr<IClient> &)> connectedCb = kSocket.connectedCb;
     590            9 :     std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb = kSocket.disconnectedCb;
     591              : 
     592            9 :     socketLocker.unlock();
     593              : 
     594              :     // attempt to add the socket to the client list
     595            9 :     auto client = addClientSocket(clientSock, kSockPath, std::move(disconnectedCb));
     596            9 :     if (!client)
     597              :     {
     598            0 :         close(clientSock);
     599            0 :         return;
     600              :     }
     601              : 
     602              :     // notify the handler that a new connection has been made
     603            9 :     if (connectedCb)
     604              :     {
     605              :         // tell the handler we have a new client
     606            9 :         connectedCb(client);
     607              :     }
     608              : }
     609              : 
     610              : // -----------------------------------------------------------------------------
     611              : /*!
     612              :     \internal
     613              :     \static
     614              : 
     615              :     Reads all the file descriptors from a message. It returns the file descriptors
     616              :     as a vector of FileDescriptor objects, these objects safely store the fd and
     617              :     close them when they're destructed.
     618              : 
     619              :  */
     620            0 : static std::vector<FileDescriptor> readMessageFds(const struct msghdr *msg, size_t limit)
     621              : {
     622            0 :     std::vector<FileDescriptor> fds;
     623              : 
     624            0 :     for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg); cmsg != nullptr; cmsg = CMSG_NXTHDR((struct msghdr *)msg, cmsg))
     625              :     {
     626            0 :         if ((cmsg->cmsg_level == SOL_SOCKET) && (cmsg->cmsg_type == SCM_RIGHTS))
     627              :         {
     628            0 :             const unsigned kFdsLength = cmsg->cmsg_len - CMSG_LEN(0);
     629            0 :             if ((kFdsLength < sizeof(int)) || ((kFdsLength % sizeof(int)) != 0))
     630              :             {
     631            0 :                 RIALTO_IPC_LOG_ERROR("invalid fd array size");
     632              :             }
     633              :             else
     634              :             {
     635            0 :                 const size_t n = kFdsLength / sizeof(int);
     636            0 :                 RIALTO_IPC_LOG_DEBUG("received %zu fds", n);
     637              : 
     638            0 :                 fds.reserve(std::min(limit, n));
     639              : 
     640            0 :                 const int *kFds = reinterpret_cast<int *>(CMSG_DATA(cmsg));
     641            0 :                 for (size_t i = 0; i < n; i++)
     642              :                 {
     643            0 :                     RIALTO_IPC_LOG_DEBUG("received fd %d", kFds[i]);
     644              : 
     645            0 :                     if (fds.size() >= limit)
     646              :                     {
     647            0 :                         RIALTO_IPC_LOG_ERROR("received to many file descriptors, "
     648              :                                              "exceeding max per message, closing left overs");
     649              :                     }
     650              :                     else
     651              :                     {
     652            0 :                         FileDescriptor fd(kFds[i]);
     653            0 :                         if (!fd.isValid())
     654              :                         {
     655            0 :                             RIALTO_IPC_LOG_ERROR("received invalid fd (couldn't dup)");
     656              :                         }
     657              :                         else
     658              :                         {
     659            0 :                             fds.emplace_back(std::move(fd));
     660              :                         }
     661              :                     }
     662              : 
     663            0 :                     if (close(kFds[i]) != 0)
     664              :                     {
     665            0 :                         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close received fd");
     666              :                     }
     667              :                 }
     668              :             }
     669              :         }
     670              :     }
     671              : 
     672            0 :     return fds;
     673              : }
     674              : 
     675              : // -----------------------------------------------------------------------------
     676              : /*!
     677              :     \internal
     678              : 
     679              :     Processes an event from a client socket.
     680              : 
     681              :  */
     682           37 : void ServerImpl::processClientSocket(uint64_t clientId, unsigned events)
     683              : {
     684              :     // take the lock while accessing the client list
     685           37 :     std::unique_lock<std::mutex> locker(m_clientsLock);
     686              : 
     687           37 :     auto it = m_clients.find(clientId);
     688           37 :     if (it == m_clients.end())
     689              :     {
     690              :         // should never happen
     691            0 :         RIALTO_IPC_LOG_ERROR("received an event from a socket with no matching client");
     692            0 :         return;
     693              :     }
     694              : 
     695              :     // check if the client is marked for closure, if so then just ignore the data
     696           37 :     if (m_condemnedClients.count(clientId) != 0)
     697              :     {
     698            0 :         return;
     699              :     }
     700              : 
     701              :     // get the socket that corresponds to the client connection
     702           37 :     const int kSockFd = it->second.sock;
     703              : 
     704              :     // get the client object
     705           37 :     std::shared_ptr<ClientImpl> clientObj = it->second.client;
     706              : 
     707              :     // can safely release the lock now we have the clientId and client object
     708           37 :     locker.unlock();
     709              : 
     710              :     // if there was an error disconnect the socket
     711           37 :     if (events & EPOLLERR)
     712              :     {
     713            0 :         RIALTO_IPC_LOG_ERROR("error detected on client socket - disconnecting client");
     714            0 :         disconnectClient(clientId);
     715            0 :         return;
     716              :     }
     717              : 
     718           37 :     if (events & EPOLLIN)
     719              :     {
     720              :         // read all messages from the client socket, we break out if the socket is closed
     721              :         // or EWOULDBLOCK is returned on a read (ie. no more messages to read)
     722              :         while (true)
     723              :         {
     724           52 :             struct msghdr msg = {nullptr};
     725           52 :             struct iovec io = {.iov_base = m_recvDataBuf, .iov_len = sizeof(m_recvDataBuf)};
     726              : 
     727           52 :             bzero(&msg, sizeof(msg));
     728           52 :             msg.msg_iov = &io;
     729           52 :             msg.msg_iovlen = 1;
     730           52 :             msg.msg_control = m_recvCtrlBuf;
     731           52 :             msg.msg_controllen = sizeof(m_recvCtrlBuf);
     732              : 
     733              :             // read one message
     734           52 :             ssize_t rd = TEMP_FAILURE_RETRY(recvmsg(kSockFd, &msg, MSG_CMSG_CLOEXEC));
     735           52 :             if (rd < 0)
     736              :             {
     737           15 :                 if (errno != EWOULDBLOCK)
     738              :                 {
     739            0 :                     RIALTO_IPC_LOG_SYS_ERROR(errno, "error reading client socket");
     740            0 :                     disconnectClient(clientId);
     741              :                 }
     742              : 
     743           37 :                 break;
     744              :             }
     745           37 :             else if (rd == 0)
     746              :             {
     747              :                 // client closed connection, and we've read all data, add to the condemned set
     748              :                 // so is cleaned up once all the events are processed
     749           22 :                 disconnectClient(clientId);
     750              : 
     751           22 :                 break;
     752              :             }
     753           15 :             else if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))
     754              :             {
     755            0 :                 RIALTO_IPC_LOG_WARN("received message from client %" PRIu64 " truncated, discarding", clientId);
     756              : 
     757              :                 // make sure to close all the fds, otherwise we'll leak them
     758            0 :                 readMessageFds(&msg, 16);
     759              :             }
     760              :             else
     761              :             {
     762              :                 // if there is control data then assume fd(s) have been passed
     763           15 :                 if (msg.msg_controllen > 0)
     764              :                 {
     765            0 :                     processClientMessage(clientObj, m_recvDataBuf, rd, readMessageFds(&msg, 16));
     766              :                 }
     767              :                 else
     768              :                 {
     769           15 :                     processClientMessage(clientObj, m_recvDataBuf, rd);
     770              :                 }
     771              :             }
     772              :         }
     773              :     }
     774           37 : }
     775              : 
     776              : // -----------------------------------------------------------------------------
     777              : /*!
     778              :     \internal
     779              : 
     780              :     Places the received file descriptors into the request message.
     781              : 
     782              :     It works by iterating over the fields in the message, finding ones that are
     783              :     marked as 'field_is_fd' and then replacing the received integer value with
     784              :     an actual file descriptor.
     785              : 
     786              :  */
     787           15 : static bool addRequestFileDescriptors(google::protobuf::Message *request, const std::vector<FileDescriptor> &requestFds)
     788              : {
     789           15 :     auto fdIterator = requestFds.begin();
     790              : 
     791           15 :     const google::protobuf::Descriptor *kDescriptor = request->GetDescriptor();
     792           15 :     const google::protobuf::Reflection *kReflection = nullptr;
     793              : 
     794           15 :     const int n = kDescriptor->field_count();
     795           47 :     for (int i = 0; i < n; i++)
     796              :     {
     797           32 :         auto fieldDescriptor = kDescriptor->field(i);
     798           32 :         if (fieldDescriptor->options().HasExtension(field_is_fd) && fieldDescriptor->options().GetExtension(field_is_fd))
     799              :         {
     800            0 :             if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
     801              :             {
     802            0 :                 RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
     803            0 :                 return false;
     804              :             }
     805              : 
     806            0 :             if (!kReflection)
     807              :             {
     808            0 :                 kReflection = request->GetReflection();
     809              :             }
     810              : 
     811            0 :             if (kReflection->HasField(*request, fieldDescriptor))
     812              :             {
     813            0 :                 if (fdIterator == requestFds.end())
     814              :                 {
     815            0 :                     RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but one was supplied");
     816            0 :                     return false;
     817              :                 }
     818              : 
     819            0 :                 kReflection->SetInt32(request, fieldDescriptor, fdIterator->fd());
     820            0 :                 ++fdIterator;
     821              :             }
     822              :         }
     823              :     }
     824              : 
     825           15 :     if (fdIterator != requestFds.end())
     826              :     {
     827            0 :         RIALTO_IPC_LOG_ERROR("received too many file descriptors in the message");
     828            0 :         return false;
     829              :     }
     830              : 
     831           15 :     return true;
     832              : }
     833              : 
     834              : // -----------------------------------------------------------------------------
     835              : /*!
     836              :     \internal
     837              : 
     838              :     Processes a message received on a client socket.
     839              : 
     840              :  */
     841           15 : void ServerImpl::processClientMessage(const std::shared_ptr<ClientImpl> &client, const uint8_t *data, size_t dataLen,
     842              :                                       const std::vector<FileDescriptor> &fds)
     843              : {
     844           15 :     RIALTO_IPC_LOG_DEBUG("processing client message of size %zu bytes (%zu fds) from client %" PRId64, dataLen,
     845              :                          fds.size(), client->id());
     846              : 
     847              :     // parse the message
     848           15 :     transport::MessageToServer message;
     849           15 :     if (!message.ParseFromArray(data, static_cast<int>(dataLen)))
     850              :     {
     851            0 :         RIALTO_IPC_LOG_ERROR("invalid request");
     852            0 :         return;
     853              :     }
     854              : 
     855           15 :     if (message.has_call())
     856              :     {
     857           15 :         processMethodCall(client, message.call(), fds);
     858              :     }
     859              :     else
     860              :     {
     861            0 :         RIALTO_IPC_LOG_WARN("received unknown message type from client");
     862              :     }
     863           15 : }
     864              : 
     865              : // -----------------------------------------------------------------------------
     866              : /*!
     867              :     \internal
     868              : 
     869              :     Processes a method call requst from a client.
     870              : 
     871              :  */
     872           15 : void ServerImpl::processMethodCall(const std::shared_ptr<ClientImpl> &client, const transport::MethodCall &call,
     873              :                                    const std::vector<FileDescriptor> &fds)
     874              : {
     875              :     // try and find the service with the given name
     876           15 :     const std::string &kServiceName = call.service_name();
     877           15 :     auto it = client->m_services.find(kServiceName);
     878           15 :     if (it == client->m_services.end())
     879              :     {
     880            0 :         RIALTO_IPC_LOG_ERROR("unknown service request '%s'", kServiceName.c_str());
     881              : 
     882            0 :         sendErrorReply(client, call.serial_id(), "Unknown service '%s'", kServiceName.c_str());
     883            0 :         return;
     884              :     }
     885              : 
     886           15 :     std::shared_ptr<google::protobuf::Service> service = it->second;
     887              : 
     888              :     // try and find the method
     889           15 :     const std::string &kMethodName = call.method_name();
     890           15 :     const google::protobuf::MethodDescriptor *kMethod = service->GetDescriptor()->FindMethodByName(kMethodName);
     891           15 :     if (!kMethod)
     892              :     {
     893            0 :         RIALTO_IPC_LOG_ERROR("no method with name '%s'", kMethodName.c_str());
     894              : 
     895            0 :         sendErrorReply(client, call.serial_id(), "Unknown method '%s'", kMethodName.c_str());
     896            0 :         return;
     897              :     }
     898              : 
     899              :     // check if the method is expecting a reply
     900           15 :     const bool kNoReply = kMethod->options().HasExtension(no_reply) && kMethod->options().GetExtension(no_reply);
     901              : 
     902              :     // parse the request data
     903           15 :     google::protobuf::Message *requestMessage = service->GetRequestPrototype(kMethod).New();
     904           15 :     if (!requestMessage->ParseFromString(call.request_message()))
     905              :     {
     906            0 :         RIALTO_IPC_LOG_ERROR("failed to parse method from array");
     907              :     }
     908           15 :     else if (!addRequestFileDescriptors(requestMessage, fds))
     909              :     {
     910            0 :         RIALTO_IPC_LOG_ERROR("mismatch of file descriptors to the request");
     911              :     }
     912              :     else
     913              :     {
     914           15 :         RIALTO_IPC_LOG_DEBUG("call{ serial %" PRIu64 " } - %s.%s { %s }", call.serial_id(), kServiceName.c_str(),
     915              :                              kMethodName.c_str(), requestMessage->ShortDebugString().c_str());
     916              : 
     917           15 :         auto *controller = new ServerControllerImpl(client, call.serial_id());
     918              : 
     919           15 :         if (kNoReply)
     920              :         {
     921              :             // we should not send a reply for this call, so call the code to handle the
     922              :             // request, but no need to pass a controller, response or closure object
     923            1 :             static google::protobuf::internal::FunctionClosure0 nullClosure(&google::protobuf::DoNothing, false);
     924            1 :             service->CallMethod(kMethod, controller, requestMessage, nullptr, &nullClosure);
     925              : 
     926            1 :             delete controller;
     927              :         }
     928              :         else
     929              :         {
     930              :             // create a response
     931           14 :             google::protobuf::Message *responseMessage = service->GetResponsePrototype(kMethod).New();
     932              : 
     933              :             // this is finally where we call the service implementation to process the request
     934           14 :             service->CallMethod(kMethod, controller, requestMessage, responseMessage,
     935              :                                 google::protobuf::NewCallback(this, &ServerImpl::handleResponse, controller,
     936              :                                                               responseMessage));
     937              :         }
     938              :     }
     939              : 
     940           15 :     delete requestMessage;
     941              : }
     942              : 
     943              : // -----------------------------------------------------------------------------
     944              : /*!
     945              :     \internal
     946              :     \threadsafe
     947              : 
     948              :     Sends a message to the given client if still connected.
     949              : 
     950              :  */
     951           14 : void ServerImpl::sendReply(uint64_t clientId, const std::shared_ptr<msghdr> &msg)
     952              : {
     953              :     // now take the lock (so the socket is not closed beneath us) and send the reply
     954           14 :     std::lock_guard<std::mutex> locker(m_clientsLock);
     955              : 
     956           14 :     auto it = m_clients.find(clientId);
     957           14 :     if (it == m_clients.end())
     958              :     {
     959            1 :         RIALTO_IPC_LOG_WARN("socket removed before error reply could be sent");
     960              :     }
     961           13 :     else if (it->second.sock < 0)
     962              :     {
     963            0 :         RIALTO_IPC_LOG_WARN("socket closed before error reply could be sent");
     964              :     }
     965           13 :     else if (!msg)
     966              :     {
     967            0 :         RIALTO_IPC_LOG_WARN("invalid msg to send on socket, ignoring");
     968              :     }
     969           13 :     else if (TEMP_FAILURE_RETRY(sendmsg(it->second.sock, msg.get(), MSG_NOSIGNAL)) !=
     970           13 :              static_cast<ssize_t>(msg->msg_iov->iov_len))
     971              :     {
     972            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to send the complete error reply message");
     973              :     }
     974           14 : }
     975              : 
     976              : // -----------------------------------------------------------------------------
     977              : /*!
     978              :     \internal
     979              : 
     980              :     Sends an error back to the client with a reason string in \a format.
     981              : 
     982              :  */
     983            0 : void ServerImpl::sendErrorReply(const std::shared_ptr<ClientImpl> &client, uint64_t serialId, const char *format, ...)
     984              : {
     985              :     // construct the error string
     986              :     va_list ap;
     987            0 :     va_start(ap, format);
     988              :     char reason[512];
     989            0 :     vsnprintf(reason, sizeof(reason), format, ap);
     990            0 :     va_end(ap);
     991              : 
     992              :     // construct the reply message
     993            0 :     auto msg = populateErrorReply(client, serialId, reason);
     994              : 
     995              :     // and send it
     996            0 :     sendReply(client->id(), msg);
     997              : }
     998              : 
     999              : // -----------------------------------------------------------------------------
    1000              : /*!
    1001              :     \internal
    1002              :     \threadsafe
    1003              : 
    1004              :     Called once the service handler code has processed the request and
    1005              :     completed it.  This may be called from a different thread if the service
    1006              :     implementation decided to off-load the request to a processing thread.
    1007              : 
    1008              :  */
    1009           14 : void ServerImpl::handleResponse(ServerControllerImpl *controller, google::protobuf::Message *response)
    1010              : {
    1011           14 :     if (!controller || !controller->m_kClient)
    1012              :     {
    1013            0 :         RIALTO_IPC_LOG_ERROR("missing controller or attached client");
    1014            0 :         return;
    1015              :     }
    1016              : 
    1017           14 :     const std::shared_ptr<const ClientImpl> kClient = controller->m_kClient;
    1018           14 :     const uint64_t kClientId = kClient->id();
    1019              : 
    1020           14 :     std::shared_ptr<msghdr> message;
    1021           14 :     if (!controller->m_failed)
    1022              :     {
    1023            8 :         message = populateReply(kClient, controller->m_kSerialId, response);
    1024              :     }
    1025              :     else
    1026              :     {
    1027            6 :         message = populateErrorReply(kClient, controller->m_kSerialId, controller->m_failureReason);
    1028              :     }
    1029              : 
    1030              :     // no longer need the controller or the response objects
    1031           14 :     delete response;
    1032           14 :     delete controller;
    1033              : 
    1034              :     // send the reply message to the given client
    1035           14 :     sendReply(kClientId, message);
    1036              : }
    1037              : 
    1038              : // -----------------------------------------------------------------------------
    1039              : /*!
    1040              :     \internal
    1041              : 
    1042              :     Gets all the file descriptors that are stored in the \a response message
    1043              :     and returns them as a vector of ints.
    1044              : 
    1045              :     It works by iterating over the fields in the message, finding ones that are
    1046              :     marked as 'field_is_fd' and inserting the fd value into control part of the
    1047              :     message.  The actual integer sent in the data part of the message is set to
    1048              :     -1.
    1049              : 
    1050              :  */
    1051           12 : static std::vector<int> getResponseFileDescriptors(google::protobuf::Message *response)
    1052              : {
    1053           12 :     std::vector<int> fds;
    1054              : 
    1055              :     // process any file descriptors from the response message
    1056           12 :     const google::protobuf::Descriptor *kDescriptor = response->GetDescriptor();
    1057           12 :     const google::protobuf::Reflection *kReflection = nullptr;
    1058              : 
    1059           12 :     const int n = kDescriptor->field_count();
    1060           25 :     for (int i = 0; i < n; i++)
    1061              :     {
    1062           13 :         auto fieldDescriptor = kDescriptor->field(i);
    1063           13 :         if (fieldDescriptor->options().HasExtension(field_is_fd) && fieldDescriptor->options().GetExtension(field_is_fd))
    1064              :         {
    1065            0 :             if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
    1066              :             {
    1067            0 :                 RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
    1068            0 :                 return {};
    1069              :             }
    1070              : 
    1071            0 :             if (!kReflection)
    1072              :             {
    1073            0 :                 kReflection = response->GetReflection();
    1074              :             }
    1075              : 
    1076            0 :             if (kReflection->HasField(*response, fieldDescriptor))
    1077              :             {
    1078            0 :                 fds.push_back(kReflection->GetInt32(*response, fieldDescriptor));
    1079            0 :                 kReflection->SetInt32(response, fieldDescriptor, -1);
    1080              :             }
    1081              :         }
    1082              :     }
    1083              : 
    1084           12 :     return fds;
    1085              : }
    1086              : 
    1087              : // -----------------------------------------------------------------------------
    1088              : /*!
    1089              :     \internal
    1090              :     \static
    1091              : 
    1092              :     Populates tha socket message buffer with the reply data for the RPC request.
    1093              : 
    1094              :  */
    1095            8 : std::shared_ptr<msghdr> ServerImpl::populateReply(const std::shared_ptr<const ClientImpl> &client, uint64_t serialId,
    1096              :                                                   google::protobuf::Message *response)
    1097              : {
    1098              :     // create the base reply
    1099            8 :     transport::MessageFromServer message;
    1100            8 :     transport::MethodCallReply *reply = message.mutable_reply();
    1101            8 :     if (!reply)
    1102              :     {
    1103            0 :         RIALTO_IPC_LOG_ERROR("failed to create mutable reply object");
    1104            0 :         return nullptr;
    1105              :     }
    1106              : 
    1107              :     // convert the message response to a data string
    1108            8 :     std::string respString = response->SerializeAsString();
    1109              : 
    1110              :     // wrap in a transport response and send that
    1111            8 :     reply->set_reply_id(serialId);
    1112            8 :     reply->set_reply_message(std::move(respString));
    1113              : 
    1114              :     // next need to check if the response message has any file descriptors in
    1115              :     // it that need to be attached
    1116            8 :     const std::vector<int> kFds = getResponseFileDescriptors(response);
    1117            8 :     const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
    1118              : 
    1119              :     // calculate the size of the reply
    1120            8 :     const size_t kRequiredDataLen = message.ByteSizeLong();
    1121            8 :     if (kRequiredDataLen > m_kMaxMessageLen)
    1122              :     {
    1123            0 :         RIALTO_IPC_LOG_ERROR("reply exceeds maximum message limit (%zu, max %zu)", kRequiredDataLen, m_kMaxMessageLen);
    1124              : 
    1125              :         // error message is too big, replace with a generic error
    1126            0 :         return populateErrorReply(client, serialId, "Internal error - reply message to large");
    1127              :     }
    1128              : 
    1129              :     // build the socket message to send
    1130              :     auto msgBuf =
    1131            8 :         m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + kRequiredDataLen);
    1132              : 
    1133            8 :     auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
    1134            8 :     bzero(header, sizeof(msghdr));
    1135              : 
    1136            8 :     auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
    1137            8 :     header->msg_control = ctrl;
    1138            8 :     header->msg_controllen = kRequiredCtrlLen;
    1139              : 
    1140            8 :     auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
    1141            8 :     header->msg_iov = iov;
    1142            8 :     header->msg_iovlen = 1;
    1143              : 
    1144            8 :     auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
    1145            8 :     iov->iov_base = data;
    1146            8 :     iov->iov_len = kRequiredDataLen;
    1147              : 
    1148              :     // copy in the data
    1149            8 :     message.SerializeWithCachedSizesToArray(data);
    1150              : 
    1151              :     // add the fds
    1152            8 :     if (!kFds.empty())
    1153              :     {
    1154            0 :         struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
    1155            0 :         if (!cmsg)
    1156              :         {
    1157            0 :             RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
    1158            0 :             return nullptr;
    1159              :         }
    1160              : 
    1161            0 :         cmsg->cmsg_level = SOL_SOCKET;
    1162            0 :         cmsg->cmsg_type = SCM_RIGHTS;
    1163            0 :         cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
    1164            0 :         memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
    1165            0 :         header->msg_controllen = cmsg->cmsg_len;
    1166              :     }
    1167              : 
    1168            8 :     RIALTO_IPC_LOG_DEBUG("reply{ serial %" PRIu64 " } - { %s }", serialId, response->ShortDebugString().c_str());
    1169              : 
    1170              :     // std::reinterpret_pointer_cast is only implemented in C++17 and newer, so for
    1171              :     // now do it manually
    1172            8 :     return std::shared_ptr<msghdr>(msgBuf, reinterpret_cast<msghdr *>(msgBuf.get()));
    1173              : }
    1174              : 
    1175              : // -----------------------------------------------------------------------------
    1176              : /*!
    1177              :     \internal
    1178              : 
    1179              :     Populates tha socket message buffer with a reply error message.
    1180              : 
    1181              :  */
    1182            6 : std::shared_ptr<msghdr> ServerImpl::populateErrorReply(const std::shared_ptr<const ClientImpl> &client,
    1183              :                                                        uint64_t serialId, const std::string &reason)
    1184              : {
    1185              :     // create the base reply
    1186            6 :     transport::MessageFromServer message;
    1187            6 :     transport::MethodCallError *error = message.mutable_error();
    1188            6 :     error->set_reply_id(serialId);
    1189            6 :     error->set_error_reason(reason);
    1190              : 
    1191              :     // check the message will fit
    1192            6 :     size_t replySize = message.ByteSizeLong();
    1193            6 :     if (replySize > m_kMaxMessageLen)
    1194              :     {
    1195            0 :         RIALTO_IPC_LOG_ERROR("error reply exceeds max message size");
    1196              : 
    1197              :         // error message is to big, replace with a generic error
    1198            0 :         error->set_error_reason("Error message truncated");
    1199            0 :         replySize = message.ByteSizeLong();
    1200              :     }
    1201              : 
    1202            6 :     RIALTO_IPC_LOG_DEBUG("error{ serial %" PRIu64 " } - \"%s\"", serialId, reason.c_str());
    1203              : 
    1204              :     // construct the message to send on the socket
    1205            6 :     auto msgBuf = m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + replySize);
    1206              : 
    1207            6 :     auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
    1208            6 :     bzero(header, sizeof(msghdr));
    1209              : 
    1210            6 :     auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr));
    1211            6 :     header->msg_iov = iov;
    1212            6 :     header->msg_iovlen = 1;
    1213              : 
    1214            6 :     auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + sizeof(iovec));
    1215            6 :     iov->iov_base = data;
    1216            6 :     iov->iov_len = replySize;
    1217              : 
    1218              :     // serialise the reply to the message buffer
    1219            6 :     message.SerializeWithCachedSizesToArray(data);
    1220              : 
    1221              :     // std::reinterpret_pointer_cast is only implemented in C++17 and newer, so for
    1222              :     // now do it manually
    1223           12 :     return std::shared_ptr<msghdr>(msgBuf, reinterpret_cast<msghdr *>(msgBuf.get()));
    1224            6 : }
    1225              : 
    1226              : // -----------------------------------------------------------------------------
    1227              : /*!
    1228              :     \threadsafe
    1229              : 
    1230              :     Returns \c true if the client with \a clientId is currently connected to
    1231              :     the server.
    1232              : 
    1233              :  */
    1234           53 : bool ServerImpl::isClientConnected(uint64_t clientId) const
    1235              : {
    1236           53 :     std::unique_lock<std::mutex> locker(m_clientsLock);
    1237          106 :     return (m_clients.count(clientId) > 0);
    1238           53 : }
    1239              : 
    1240              : // -----------------------------------------------------------------------------
    1241              : /*!
    1242              :     \threadsafe
    1243              : 
    1244              :     May be called internally when there is an error on the socket, or externally
    1245              :     (possibly from a different thread) if the handler or service code decides to
    1246              :     close the client connection.
    1247              : 
    1248              :  */
    1249           23 : void ServerImpl::disconnectClient(uint64_t clientId)
    1250              : {
    1251           23 :     std::unique_lock<std::mutex> locker(m_clientsLock);
    1252           23 :     m_condemnedClients.insert(clientId);
    1253           23 :     locker.unlock();
    1254              : 
    1255           23 :     wakeEventLoop();
    1256              : }
    1257              : 
    1258              : // -----------------------------------------------------------------------------
    1259              : /*!
    1260              :     \threadsafe
    1261              : 
    1262              :     Called via the IAVBusClient interface when an async event should be sent.
    1263              : 
    1264              :     This may be called from any thread, or from within the rpc message handler.
    1265              : 
    1266              :     The \a clientId is the client to send the event to.
    1267              : 
    1268              :  */
    1269            4 : bool ServerImpl::sendEvent(uint64_t clientId, const std::shared_ptr<google::protobuf::Message> &eventMessage)
    1270              : {
    1271              :     // gets the file descriptors from the event message
    1272            4 :     const std::vector<int> kFds = getResponseFileDescriptors(eventMessage.get());
    1273            4 :     const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
    1274              : 
    1275              :     // create the base reply
    1276            4 :     transport::MessageFromServer message;
    1277            4 :     transport::EventFromServer *event = message.mutable_event();
    1278            4 :     if (!event)
    1279              :     {
    1280            0 :         RIALTO_IPC_LOG_ERROR("failed to create mutable event object");
    1281            0 :         return false;
    1282              :     }
    1283              : 
    1284            4 :     event->set_event_name(eventMessage->GetTypeName());
    1285              : 
    1286              :     // convert the event to a data string
    1287            4 :     std::string respString = eventMessage->SerializeAsString();
    1288              : 
    1289              :     // wrap in a transport response and send that
    1290            4 :     event->set_message(std::move(respString));
    1291              : 
    1292              :     // check the reply will fit
    1293            4 :     size_t requiredDataLen = message.ByteSizeLong();
    1294            4 :     if (requiredDataLen > m_kMaxMessageLen)
    1295              :     {
    1296            0 :         RIALTO_IPC_LOG_ERROR("event message to big to fit in buffer (size %zu, max size %zu)", requiredDataLen,
    1297              :                              m_kMaxMessageLen);
    1298            0 :         return false;
    1299              :     }
    1300              : 
    1301              :     // build the socket message to send
    1302              :     auto msgBuf =
    1303            4 :         m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + requiredDataLen);
    1304              : 
    1305            4 :     auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
    1306            4 :     bzero(header, sizeof(msghdr));
    1307              : 
    1308            4 :     auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
    1309            4 :     header->msg_control = ctrl;
    1310            4 :     header->msg_controllen = kRequiredCtrlLen;
    1311              : 
    1312            4 :     auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
    1313            4 :     header->msg_iov = iov;
    1314            4 :     header->msg_iovlen = 1;
    1315              : 
    1316            4 :     auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
    1317            4 :     iov->iov_base = data;
    1318            4 :     iov->iov_len = requiredDataLen;
    1319              : 
    1320              :     // copy in the data
    1321            4 :     message.SerializeWithCachedSizesToArray(data);
    1322              : 
    1323              :     // add the fds
    1324            4 :     if (!kFds.empty())
    1325              :     {
    1326            0 :         struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
    1327            0 :         if (!cmsg)
    1328              :         {
    1329            0 :             RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
    1330            0 :             return false;
    1331              :         }
    1332              : 
    1333            0 :         cmsg->cmsg_level = SOL_SOCKET;
    1334            0 :         cmsg->cmsg_type = SCM_RIGHTS;
    1335            0 :         cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
    1336            0 :         memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
    1337            0 :         header->msg_controllen = cmsg->cmsg_len;
    1338              :     }
    1339              : 
    1340              :     // finally, take the lock (so the socket is not closed beneath us) and send the reply
    1341            4 :     std::unique_lock<std::mutex> locker(m_clientsLock);
    1342              : 
    1343            4 :     auto it = m_clients.find(clientId);
    1344            4 :     if (it == m_clients.end() || it->second.sock < 0)
    1345              :     {
    1346            0 :         RIALTO_IPC_LOG_WARN("socket closed before event could be sent");
    1347            0 :         return false;
    1348              :     }
    1349            4 :     else if (TEMP_FAILURE_RETRY(sendmsg(it->second.sock, header, MSG_NOSIGNAL)) != static_cast<ssize_t>(requiredDataLen))
    1350              :     {
    1351            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to send the complete event message");
    1352            0 :         return false;
    1353              :     }
    1354              : 
    1355            4 :     locker.unlock();
    1356              : 
    1357            4 :     RIALTO_IPC_LOG_DEBUG("event{ %s } - { %s }", eventMessage->GetTypeName().c_str(),
    1358              :                          eventMessage->ShortDebugString().c_str());
    1359              : 
    1360            4 :     return true;
    1361              : }
    1362              : } // namespace firebolt::rialto::ipc
        

Generated by: LCOV version 2.0-1