LCOV - code coverage report
Current view: top level - ipc/common/source - NamedSocket.cpp (source / functions) Coverage Total Hit
Test: coverage.info Lines: 73.2 % 149 109
Test Date: 2025-03-21 11:02:39 Functions: 100.0 % 15 15

            Line data    Source code
       1              : /*
       2              :  * If not stated otherwise in this file or this component's LICENSE file the
       3              :  * following copyright and licenses apply:
       4              :  *
       5              :  * Copyright 2025 Sky UK
       6              :  *
       7              :  * Licensed under the Apache License, Version 2.0 (the "License");
       8              :  * you may not use this file except in compliance with the License.
       9              :  * You may obtain a copy of the License at
      10              :  *
      11              :  * http://www.apache.org/licenses/LICENSE-2.0
      12              :  *
      13              :  * Unless required by applicable law or agreed to in writing, software
      14              :  * distributed under the License is distributed on an "AS IS" BASIS,
      15              :  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      16              :  * See the License for the specific language governing permissions and
      17              :  * limitations under the License.
      18              :  */
      19              : 
      20              : #include "NamedSocket.h"
      21              : #include "IpcLogging.h"
      22              : #include <grp.h>
      23              : #include <pwd.h>
      24              : #include <stdexcept>
      25              : #include <sys/file.h>
      26              : #include <sys/socket.h>
      27              : #include <sys/stat.h>
      28              : #include <sys/un.h>
      29              : #include <unistd.h>
      30              : #include <utility>
      31              : 
      32              : namespace
      33              : {
      34              : constexpr uid_t kNoOwnerChange = -1; // -1 means chown() won't change the owner
      35              : constexpr gid_t kNoGroupChange = -1; // -1 means chown() won't change the group
      36              : } // namespace
      37              : 
      38              : namespace firebolt::rialto::ipc
      39              : {
      40            6 : INamedSocketFactory &INamedSocketFactory::getFactory()
      41              : {
      42            6 :     static NamedSocketFactory factory;
      43            6 :     return factory;
      44              : }
      45              : 
      46            3 : std::unique_ptr<INamedSocket> NamedSocketFactory::createNamedSocket() const
      47              : try
      48              : {
      49            3 :     return std::make_unique<NamedSocket>();
      50              : }
      51            0 : catch (const std::runtime_error &error)
      52              : {
      53            0 :     RIALTO_IPC_LOG_ERROR("Failed to create named socket: %s", error.what());
      54            0 :     return nullptr;
      55              : }
      56              : 
      57            3 : std::unique_ptr<INamedSocket> NamedSocketFactory::createNamedSocket(const std::string &socketPath) const
      58              : try
      59              : {
      60            3 :     return std::make_unique<NamedSocket>(socketPath);
      61              : }
      62            0 : catch (const std::runtime_error &error)
      63              : {
      64            0 :     RIALTO_IPC_LOG_ERROR("Failed to create named socket: %s", error.what());
      65            0 :     return nullptr;
      66              : }
      67              : 
      68            3 : NamedSocket::NamedSocket()
      69              : {
      70            3 :     RIALTO_IPC_LOG_MIL("Creating new socket without binding");
      71              : 
      72              :     // Create the socket
      73            3 :     m_sockFd = ::socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK, 0);
      74            3 :     if (m_sockFd == -1)
      75              :     {
      76            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "socket error");
      77            0 :         throw std::runtime_error("socket error");
      78              :     }
      79              : 
      80            3 :     RIALTO_IPC_LOG_MIL("Socket created, fd: %d", m_sockFd);
      81              : }
      82              : 
      83            3 : NamedSocket::NamedSocket(const std::string &socketPath)
      84              : {
      85            3 :     RIALTO_IPC_LOG_MIL("Creating named socket with name: %s", socketPath.c_str());
      86            3 :     m_sockPath = socketPath;
      87              : 
      88              :     // Create the socket
      89            3 :     m_sockFd = ::socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK, 0);
      90            3 :     if (m_sockFd == -1)
      91              :     {
      92            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "socket error");
      93            0 :         throw std::runtime_error("socket error");
      94              :     }
      95              : 
      96              :     // get the socket lock
      97            3 :     if (!getSocketLock())
      98              :     {
      99            0 :         closeListeningSocket();
     100            0 :         throw std::runtime_error("lock error");
     101              :     }
     102              : 
     103              :     // bind to the given path
     104            3 :     struct sockaddr_un addr = {0};
     105            3 :     memset(&addr, 0x00, sizeof(addr));
     106            3 :     addr.sun_family = AF_UNIX;
     107            3 :     strncpy(addr.sun_path, socketPath.c_str(), sizeof(addr.sun_path) - 1);
     108              : 
     109            3 :     if (::bind(m_sockFd, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr)) == -1)
     110              :     {
     111            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "bind error");
     112              : 
     113            0 :         closeListeningSocket();
     114            0 :         throw std::runtime_error("bind error");
     115              :     }
     116              : 
     117            3 :     RIALTO_IPC_LOG_MIL("Named socket with name: %s created, fd: %d", m_sockPath.c_str(), m_sockFd);
     118              : }
     119              : 
     120           12 : NamedSocket::~NamedSocket()
     121              : {
     122            6 :     RIALTO_IPC_LOG_MIL("Close named socket with name: %s, fd: %d", m_sockPath.c_str(), m_sockFd);
     123            6 :     closeListeningSocket();
     124           12 : }
     125              : 
     126            4 : int NamedSocket::getFd() const
     127              : {
     128            4 :     return m_sockFd;
     129              : }
     130              : 
     131            1 : bool NamedSocket::setSocketPermissions(unsigned int socketPermissions) const
     132              : {
     133            1 :     errno = 0;
     134            1 :     if (chmod(m_sockPath.c_str(), socketPermissions) != 0)
     135              :     {
     136            0 :         RIALTO_IPC_LOG_SYS_WARN(errno, "Failed to change the permissions on the IPC socket");
     137            0 :         return false;
     138              :     }
     139            1 :     return true;
     140              : }
     141              : 
     142            1 : bool NamedSocket::setSocketOwnership(const std::string &socketOwner, const std::string &socketGroup) const
     143              : {
     144            1 :     uid_t ownerId = getSocketOwnerId(socketOwner);
     145            1 :     gid_t groupId = getSocketGroupId(socketGroup);
     146              : 
     147            1 :     if (ownerId != kNoOwnerChange || groupId != kNoGroupChange)
     148              :     {
     149            1 :         errno = 0;
     150            1 :         if (chown(m_sockPath.c_str(), ownerId, groupId) != 0)
     151              :         {
     152            1 :             RIALTO_IPC_LOG_SYS_WARN(errno, "Failed to change the owner/group for the IPC socket");
     153              :         }
     154              :     }
     155            1 :     return true;
     156              : }
     157              : 
     158            2 : bool NamedSocket::blockNewConnections() const
     159              : {
     160            2 :     if (m_sockPath.empty())
     161              :     {
     162            1 :         RIALTO_IPC_LOG_DEBUG("No need to block new connections - socket not configured");
     163            1 :         return true;
     164              :     }
     165            1 :     RIALTO_IPC_LOG_INFO("Block new connections for: %s", m_sockPath.c_str());
     166            1 :     if (listen(m_sockFd, 0) == -1)
     167              :     {
     168            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "blockNewConnections: listen error");
     169            0 :         return false;
     170              :     }
     171            1 :     return true;
     172              : }
     173              : 
     174            3 : bool NamedSocket::bind(const std::string &socketPath)
     175              : {
     176            3 :     if (!m_sockPath.empty())
     177              :     {
     178            1 :         RIALTO_IPC_LOG_DEBUG("no need to bind again");
     179            1 :         return true;
     180              :     }
     181            2 :     RIALTO_IPC_LOG_MIL("Binding socket with fd: %d with name: %s", m_sockFd, socketPath.c_str());
     182            2 :     m_sockPath = socketPath;
     183              : 
     184              :     // get the socket lock
     185            2 :     if (!getSocketLock())
     186              :     {
     187            0 :         closeListeningSocket();
     188            0 :         return false;
     189              :     }
     190              : 
     191              :     // bind to the given path
     192            2 :     struct sockaddr_un addr = {0};
     193            2 :     memset(&addr, 0x00, sizeof(addr));
     194            2 :     addr.sun_family = AF_UNIX;
     195            2 :     strncpy(addr.sun_path, socketPath.c_str(), sizeof(addr.sun_path) - 1);
     196              : 
     197            2 :     if (::bind(m_sockFd, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr)) == -1)
     198              :     {
     199            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "bind error");
     200              : 
     201            0 :         closeListeningSocket();
     202            0 :         return false;
     203              :     }
     204              : 
     205            2 :     RIALTO_IPC_LOG_MIL("Named socket with fd: %d bound with path: %s", m_sockFd, m_sockPath.c_str());
     206              : 
     207            2 :     return true;
     208              : }
     209              : 
     210            6 : void NamedSocket::closeListeningSocket()
     211              : {
     212            6 :     if (!m_sockPath.empty() && (unlink(m_sockPath.c_str()) != 0) && (errno != ENOENT))
     213            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket @ '%s'", m_sockPath.c_str());
     214            6 :     if ((m_sockFd >= 0) && (close(m_sockFd) != 0))
     215            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close listening socket");
     216              : 
     217            6 :     if (!m_lockPath.empty() && (unlink(m_lockPath.c_str()) != 0) && (errno != ENOENT))
     218            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket lock file @ '%s'", m_lockPath.c_str());
     219            6 :     if ((m_lockFd >= 0) && (close(m_lockFd) != 0))
     220            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket lock file");
     221              : 
     222            6 :     m_sockFd = -1;
     223            6 :     m_sockPath.clear();
     224              : 
     225            6 :     m_lockFd = -1;
     226            6 :     m_lockPath.clear();
     227              : }
     228              : 
     229            5 : bool NamedSocket::getSocketLock()
     230              : {
     231            5 :     std::string lockPath = m_sockPath + ".lock";
     232            5 :     int fd = open(lockPath.c_str(), O_CREAT | O_CLOEXEC, (S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP));
     233            5 :     if (fd < 0)
     234              :     {
     235            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to create / open lockfile @ '%s' (check permissions)", lockPath.c_str());
     236            0 :         return false;
     237              :     }
     238              : 
     239            5 :     if (flock(fd, LOCK_EX | LOCK_NB) < 0)
     240              :     {
     241            0 :         RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to lock lockfile @ '%s', maybe another server is running",
     242              :                                  lockPath.c_str());
     243            0 :         close(fd);
     244            0 :         return false;
     245              :     }
     246              : 
     247            5 :     struct stat sbuf = {0};
     248            5 :     if (stat(m_sockPath.c_str(), &sbuf) < 0)
     249              :     {
     250            5 :         if (errno != ENOENT)
     251              :         {
     252            0 :             RIALTO_IPC_LOG_SYS_ERROR(errno, "did not manage to stat existing socket @ '%s'", m_sockPath.c_str());
     253            0 :             close(fd);
     254            0 :             return false;
     255              :         }
     256              :     }
     257            0 :     else if ((sbuf.st_mode & S_IWUSR) || (sbuf.st_mode & S_IWGRP))
     258              :     {
     259            0 :         unlink(m_sockPath.c_str());
     260              :     }
     261              : 
     262            5 :     m_lockFd = fd;
     263            5 :     m_lockPath = std::move(lockPath);
     264              : 
     265            5 :     return true;
     266              : }
     267              : 
     268            1 : uid_t NamedSocket::getSocketOwnerId(const std::string &socketOwner) const
     269              : {
     270            1 :     uid_t ownerId = kNoOwnerChange;
     271            1 :     const size_t kBufferSize = sysconf(_SC_GETPW_R_SIZE_MAX);
     272            1 :     if (!socketOwner.empty() && kBufferSize > 0)
     273              :     {
     274            1 :         errno = 0;
     275            1 :         passwd passwordStruct{};
     276            1 :         passwd *passwordResult = nullptr;
     277            1 :         char buffer[kBufferSize];
     278            1 :         int result = getpwnam_r(socketOwner.c_str(), &passwordStruct, buffer, kBufferSize, &passwordResult);
     279            1 :         if (result == 0 && passwordResult)
     280              :         {
     281            1 :             ownerId = passwordResult->pw_uid;
     282              :         }
     283              :         else
     284              :         {
     285            0 :             RIALTO_IPC_LOG_SYS_WARN(errno, "Failed to determine ownerId for '%s'", socketOwner.c_str());
     286              :         }
     287            1 :     }
     288            1 :     return ownerId;
     289              : }
     290              : 
     291            1 : gid_t NamedSocket::getSocketGroupId(const std::string &socketGroup) const
     292              : {
     293            1 :     gid_t groupId = kNoGroupChange;
     294            1 :     const size_t kBufferSize = sysconf(_SC_GETPW_R_SIZE_MAX);
     295            1 :     if (!socketGroup.empty() && kBufferSize > 0)
     296              :     {
     297            1 :         errno = 0;
     298            1 :         group groupStruct{};
     299            1 :         group *groupResult = nullptr;
     300            1 :         char buffer[kBufferSize];
     301            1 :         int result = getgrnam_r(socketGroup.c_str(), &groupStruct, buffer, kBufferSize, &groupResult);
     302            1 :         if (result == 0 && groupResult)
     303              :         {
     304            1 :             groupId = groupResult->gr_gid;
     305              :         }
     306              :         else
     307              :         {
     308            0 :             RIALTO_IPC_LOG_SYS_WARN(errno, "Failed to determine groupId for '%s'", socketGroup.c_str());
     309              :         }
     310            1 :     }
     311            1 :     return groupId;
     312              : }
     313              : } // namespace firebolt::rialto::ipc
        

Generated by: LCOV version 2.0-1