Line data Source code
1 : /*
2 : * If not stated otherwise in this file or this component's LICENSE file the
3 : * following copyright and licenses apply:
4 : *
5 : * Copyright 2022 Sky UK
6 : *
7 : * Licensed under the Apache License, Version 2.0 (the "License");
8 : * you may not use this file except in compliance with the License.
9 : * You may obtain a copy of the License at
10 : *
11 : * http://www.apache.org/licenses/LICENSE-2.0
12 : *
13 : * Unless required by applicable law or agreed to in writing, software
14 : * distributed under the License is distributed on an "AS IS" BASIS,
15 : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 : * See the License for the specific language governing permissions and
17 : * limitations under the License.
18 : */
19 :
20 : #include "IpcServerImpl.h"
21 : #include "IIpcServerFactory.h"
22 : #include "IpcClientImpl.h"
23 : #include "IpcLogging.h"
24 : #include "IpcServerControllerImpl.h"
25 :
26 : #include "rialtoipc.pb.h"
27 :
28 : #include <google/protobuf/service.h>
29 :
#include <algorithm>
#include <cinttypes>
#include <cstdarg>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
36 :
37 : #include <fcntl.h>
38 : #include <poll.h>
39 : #include <sys/epoll.h>
40 : #include <sys/eventfd.h>
41 : #include <sys/file.h>
42 : #include <sys/socket.h>
43 : #include <sys/stat.h>
44 : #include <sys/un.h>
45 : #include <unistd.h>
46 :
// epoll user-data tags used by process() to tell event sources apart:
//   WAKE_EVENT_ID                                 - the wake eventfd
//   [FIRST_LISTENING_SOCKET_ID, FIRST_CLIENT_ID)  - listening sockets
//   >= FIRST_CLIENT_ID                            - connected client sockets
#define WAKE_EVENT_ID static_cast<uint64_t>(0)
#define FIRST_LISTENING_SOCKET_ID static_cast<uint64_t>(1)
#define FIRST_CLIENT_ID static_cast<uint64_t>(10000)
50 :
51 : namespace firebolt::rialto::ipc
52 : {
53 : const size_t ServerImpl::m_kMaxMessageLen = (128 * 1024);
54 :
55 26 : std::shared_ptr<IServerFactory> IServerFactory::createFactory()
56 : {
57 26 : std::shared_ptr<IServerFactory> factory;
58 : try
59 : {
60 26 : factory = std::make_shared<ServerFactory>();
61 : }
62 0 : catch (const std::exception &e)
63 : {
64 0 : RIALTO_IPC_LOG_ERROR("Failed to create the server factory, reason: %s", e.what());
65 : }
66 :
67 26 : return factory;
68 : }
69 :
70 26 : std::shared_ptr<IServer> ServerFactory::create()
71 : {
72 26 : return std::make_shared<ServerImpl>();
73 : }
74 :
75 26 : ServerImpl::ServerImpl()
76 26 : : m_pollFd(-1), m_wakeEventFd(-1), m_socketIdCounter(FIRST_LISTENING_SOCKET_ID),
77 26 : m_clientIdCounter(FIRST_CLIENT_ID), m_recvDataBuf{0}, m_recvCtrlBuf{0}
78 : {
79 : // create the eventfd use to wake the poll loop
80 26 : m_wakeEventFd = eventfd(0, EFD_CLOEXEC);
81 26 : if (m_wakeEventFd < 0)
82 : {
83 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "eventfd failed");
84 0 : return;
85 : }
86 :
87 : // create epoll loop
88 26 : m_pollFd = epoll_create1(EPOLL_CLOEXEC);
89 26 : if (m_pollFd < 0)
90 : {
91 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_create1 failed");
92 0 : return;
93 : }
94 :
95 : // add the wake event to epoll
96 26 : epoll_event event = {.events = EPOLLIN, .data = {.u64 = WAKE_EVENT_ID}};
97 26 : if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, m_wakeEventFd, &event) != 0)
98 : {
99 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
100 : }
101 : }
102 :
103 26 : ServerImpl::~ServerImpl()
104 : {
105 26 : if ((m_pollFd >= 0) && (close(m_pollFd) != 0))
106 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close epoll");
107 :
108 26 : if ((m_wakeEventFd >= 0) && (close(m_wakeEventFd) != 0))
109 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close eventfd");
110 :
111 35 : for (const auto &entry : m_sockets)
112 : {
113 9 : const Socket &kSocket = entry.second;
114 :
115 9 : if (unlink(kSocket.sockPath.c_str()) != 0)
116 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket @ '%s'", kSocket.sockPath.c_str());
117 9 : if (close(kSocket.sockFd) != 0)
118 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close listening socket");
119 :
120 9 : if (unlink(kSocket.lockPath.c_str()) != 0)
121 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket lock file @ '%s'", kSocket.lockPath.c_str());
122 9 : if (close(kSocket.lockFd) != 0)
123 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket lock file");
124 : }
125 26 : }
126 :
127 : // -----------------------------------------------------------------------------
128 : /*!
129 : \static
130 : \internal
131 :
132 : Creates (if required) and takes the file lock associated with the socket
133 : path in the \a socket object.
134 :
135 : */
136 9 : bool ServerImpl::getSocketLock(Socket *socket)
137 : {
138 9 : std::string lockPath = socket->sockPath + ".lock";
139 9 : int fd = open(lockPath.c_str(), O_CREAT | O_CLOEXEC, (S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP));
140 9 : if (fd < 0)
141 : {
142 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to create / open lockfile @ '%s' (check permissions)", lockPath.c_str());
143 0 : return false;
144 : }
145 :
146 9 : if (flock(fd, LOCK_EX | LOCK_NB) < 0)
147 : {
148 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to lock lockfile @ '%s', maybe another server is running",
149 : lockPath.c_str());
150 0 : close(fd);
151 0 : return false;
152 : }
153 :
154 9 : struct stat sbuf = {0};
155 9 : if (stat(socket->sockPath.c_str(), &sbuf) < 0)
156 : {
157 9 : if (errno != ENOENT)
158 : {
159 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "did not manage to stat existing socket @ '%s'", socket->sockPath.c_str());
160 0 : close(fd);
161 0 : return false;
162 : }
163 : }
164 0 : else if ((sbuf.st_mode & S_IWUSR) || (sbuf.st_mode & S_IWGRP))
165 : {
166 0 : unlink(socket->sockPath.c_str());
167 : }
168 :
169 9 : socket->lockFd = fd;
170 9 : socket->lockPath = std::move(lockPath);
171 :
172 9 : return true;
173 : }
174 :
175 : // -----------------------------------------------------------------------------
176 : /*!
177 : \static
178 : \internal
179 :
180 : Closes the file descriptors and the deletes the files stored in the \a socket
181 : object.
182 :
183 : */
184 0 : void ServerImpl::closeListeningSocket(Socket *socket)
185 : {
186 0 : if (!socket->sockPath.empty() && (unlink(socket->sockPath.c_str()) != 0) && (errno != ENOENT))
187 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket @ '%s'", socket->sockPath.c_str());
188 0 : if ((socket->sockFd >= 0) && (close(socket->sockFd) != 0))
189 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close listening socket");
190 :
191 0 : if (!socket->lockPath.empty() && (unlink(socket->lockPath.c_str()) != 0) && (errno != ENOENT))
192 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket lock file @ '%s'", socket->lockPath.c_str());
193 0 : if ((socket->lockFd >= 0) && (close(socket->lockFd) != 0))
194 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket lock file");
195 :
196 0 : socket->sockFd = -1;
197 0 : socket->sockPath.clear();
198 :
199 0 : socket->lockFd = -1;
200 0 : socket->lockPath.clear();
201 : }
202 :
203 9 : bool ServerImpl::addSocket(const std::string &socketPath,
204 : std::function<void(const std::shared_ptr<IClient> &)> clientConnectedCb,
205 : std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb)
206 : {
207 : // store the path
208 9 : Socket socket;
209 9 : socket.sockPath = socketPath;
210 :
211 : // create the socket
212 9 : socket.sockFd = ::socket(AF_UNIX, SOCK_SEQPACKET | SOCK_CLOEXEC | SOCK_NONBLOCK, 0);
213 9 : if (socket.sockFd == -1)
214 : {
215 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "socket error");
216 0 : return false;
217 : }
218 :
219 : // get the socket lock
220 9 : if (!getSocketLock(&socket))
221 : {
222 0 : closeListeningSocket(&socket);
223 0 : return false;
224 : }
225 :
226 : // bind to the given path
227 9 : struct sockaddr_un addr = {0};
228 9 : memset(&addr, 0x00, sizeof(addr));
229 9 : addr.sun_family = AF_UNIX;
230 9 : strncpy(addr.sun_path, socketPath.c_str(), sizeof(addr.sun_path) - 1);
231 :
232 9 : if (bind(socket.sockFd, (struct sockaddr *)&addr, sizeof(addr)) == -1)
233 : {
234 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "bind error");
235 :
236 0 : closeListeningSocket(&socket);
237 0 : return false;
238 : }
239 :
240 : // put in listening mode
241 9 : if (listen(socket.sockFd, 1) == -1)
242 : {
243 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "listen error");
244 :
245 0 : closeListeningSocket(&socket);
246 0 : return false;
247 : }
248 :
249 : // create an id for the listening socket
250 9 : const uint64_t kSocketId = m_socketIdCounter++;
251 9 : if (kSocketId >= FIRST_CLIENT_ID)
252 : {
253 : // should never happen, we'd run out of file descriptors before
254 : // we hit the 10k limit on listening sockets
255 0 : RIALTO_IPC_LOG_ERROR("too many listening sockets");
256 :
257 0 : closeListeningSocket(&socket);
258 0 : return false;
259 : }
260 :
261 : // add the socket to epoll
262 9 : epoll_event event = {.events = EPOLLIN, .data = {.u64 = kSocketId}};
263 9 : if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socket.sockFd, &event) != 0)
264 : {
265 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add listening socket");
266 :
267 0 : closeListeningSocket(&socket);
268 0 : return false;
269 : }
270 :
271 : // store the client connected / disconnected callbacks
272 9 : socket.connectedCb = clientConnectedCb;
273 9 : socket.disconnectedCb = clientDisconnectedCb;
274 :
275 : // add to the internal map
276 9 : std::lock_guard<std::mutex> locker(m_socketsLock);
277 9 : m_sockets.emplace(kSocketId, std::move(socket));
278 :
279 9 : RIALTO_IPC_LOG_INFO("added listening socket '%s' to server", socketPath.c_str());
280 :
281 9 : return true;
282 : }
283 :
284 14 : std::shared_ptr<IClient> ServerImpl::addClient(int socketFd,
285 : std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb)
286 : {
287 : // sanity check the supplied socket is of the right type
288 : struct sockaddr addr;
289 14 : socklen_t len = sizeof(sockaddr);
290 14 : if ((getsockname(socketFd, &addr, &len) < 0) || (len < sizeof(sa_family_t)))
291 : {
292 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get name of supplied socket");
293 0 : return nullptr;
294 : }
295 14 : if (addr.sa_family != AF_UNIX)
296 : {
297 0 : RIALTO_IPC_LOG_ERROR("supplied client socket is not a unix domain socket");
298 0 : return nullptr;
299 : }
300 :
301 14 : int type = 0;
302 14 : len = sizeof(type);
303 14 : if ((getsockopt(socketFd, SOL_SOCKET, SO_TYPE, &type, &len) < 0) || (len != sizeof(type)))
304 : {
305 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
306 0 : return nullptr;
307 : }
308 14 : if (type != SOCK_SEQPACKET)
309 : {
310 0 : RIALTO_IPC_LOG_ERROR("supplied client socket is not of type SOCK_SEQPACKET");
311 0 : return nullptr;
312 : }
313 :
314 : // dup the socket and set the O_CLOEXEC bit
315 14 : int duppedFd = fcntl(socketFd, F_DUPFD_CLOEXEC, 3);
316 14 : if (duppedFd < 0)
317 : {
318 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
319 0 : return nullptr;
320 : }
321 :
322 : // ensure the SOCK_NONBLOCK flag is set
323 14 : int flags = fcntl(duppedFd, F_GETFL);
324 14 : if ((flags < 0) || (fcntl(duppedFd, F_SETFL, flags | O_NONBLOCK) < 0))
325 : {
326 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to set socket to non-blocking mode");
327 0 : close(duppedFd);
328 0 : return nullptr;
329 : }
330 :
331 : // finally, add the socket to the list of clients
332 28 : auto client = addClientSocket(duppedFd, "", std::move(clientDisconnectedCb));
333 14 : if (!client)
334 : {
335 0 : close(duppedFd);
336 : }
337 :
338 14 : return client;
339 : }
340 :
341 0 : int ServerImpl::fd() const
342 : {
343 0 : return m_pollFd;
344 : }
345 :
346 : // -----------------------------------------------------------------------------
347 : /*!
348 : \internal
349 :
350 : Writes to the eventfd to wake the event loop. Typically called when requesting
351 : it to shutdown or external code has requested that a client be disconnected.
352 :
353 : */
354 23 : void ServerImpl::wakeEventLoop() const
355 : {
356 23 : if (m_wakeEventFd < 0)
357 : {
358 0 : RIALTO_IPC_LOG_ERROR("invalid wake event fd");
359 : }
360 : else
361 : {
362 23 : uint64_t value = 1;
363 23 : if (TEMP_FAILURE_RETRY(::write(m_wakeEventFd, &value, sizeof(value))) != sizeof(value))
364 : {
365 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to write to the event fd");
366 : }
367 : }
368 23 : }
369 :
370 2879 : bool ServerImpl::wait(int timeoutMSecs)
371 : {
372 2879 : if (m_pollFd < 0)
373 : {
374 0 : return false;
375 : }
376 :
377 : // wait for any event (with timeout)
378 : struct pollfd fds[2];
379 2879 : fds[0].fd = m_pollFd;
380 2879 : fds[0].events = POLLIN;
381 :
382 2879 : int rc = TEMP_FAILURE_RETRY(poll(fds, 1, timeoutMSecs));
383 2879 : if (rc < 0)
384 : {
385 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "poll failed?");
386 0 : return false;
387 : }
388 :
389 2879 : return true;
390 : }
391 :
/*!
    Handles all currently pending events without blocking.

    epoll_wait() is called with a 0 timeout, so this never blocks; wait() can
    be used beforehand to block until events are pending. The epoll user-data
    tag tells the event source apart:
      - WAKE_EVENT_ID: the wake eventfd, which is drained.
      - below FIRST_CLIENT_ID: a listening socket, where a new connection is
        accepted.
      - otherwise: a client socket, handed to processClientSocket().

    Afterwards every condemned client is removed from m_clients, removed from
    epoll, shut down and closed, and its disconnect callback is invoked.

    Returns false if epoll is missing or epoll_wait() fails, true otherwise.
*/
bool ServerImpl::process()
{
    if (m_pollFd < 0)
    {
        RIALTO_IPC_LOG_ERROR("missing epoll");
        return false;
    }

    // read up to 32 events
    const int kMaxEvents = 32;
    struct epoll_event events[kMaxEvents];

    // timeout of 0 => poll only, never block
    int rc = TEMP_FAILURE_RETRY(epoll_wait(m_pollFd, events, kMaxEvents, 0));
    if (rc < 0)
    {
        RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_wait failed");
        return false;
    }

    // process the events (maybe 0 if timed out)
    for (int i = 0; i < rc; i++)
    {
        const struct epoll_event &kEvent = events[i];

        // check if a wake event, in which case just clear the eventfd
        if (kEvent.data.u64 == WAKE_EVENT_ID)
        {
            uint64_t ignore;
            if (TEMP_FAILURE_RETRY(::read(m_wakeEventFd, &ignore, sizeof(ignore))) != sizeof(ignore))
            {
                RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to read wake eventfd");
            }
        }

        // check for events on the listening socket
        // (ids in [FIRST_LISTENING_SOCKET_ID, FIRST_CLIENT_ID) are listening sockets, see addSocket())
        else if (kEvent.data.u64 < FIRST_CLIENT_ID)
        {
            if (kEvent.events & EPOLLIN)
                processNewConnection(kEvent.data.u64);
            if (kEvent.events & EPOLLERR)
                RIALTO_IPC_LOG_ERROR("error occurred on listening socket");
        }

        // otherwise, the event must have come from a socket
        else
        {
            processClientSocket(kEvent.data.u64, kEvent.events);
        }
    }

    // if we have client sockets that are condemned then we need to shut down
    // and close them as well as remove from epoll
    std::unique_lock<std::mutex> locker(m_clientsLock);
    if (!m_condemnedClients.empty())
    {
        // take a copy of the set so we can process without the lock held
        // (clients condemned while the lock is dropped below land in the now
        // empty m_condemnedClients and are handled on the next call)
        std::set<uint64_t> theCondemned;
        m_condemnedClients.swap(theCondemned);

        for (uint64_t clientId : theCondemned)
        {
            auto it = m_clients.find(clientId);
            if (it == m_clients.end())
            {
                RIALTO_IPC_LOG_ERROR("failed to find condemned client");
                continue;
            }

            // copy the details out, as the map entry is erased next
            ClientDetails details = it->second;

            // remove from the list of clients
            m_clients.erase(it);

            // drop the lock while closing the connection and removing from epoll
            // (the disconnect callback below also runs without the lock held)
            locker.unlock();

            // remove the socket from epoll and close it
            if (details.sock >= 0)
            {
                if (epoll_ctl(m_pollFd, EPOLL_CTL_DEL, details.sock, nullptr) != 0)
                    RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket from epoll");

                if (shutdown(details.sock, SHUT_RDWR) != 0)
                    RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to shutdown socket");
                if (close(details.sock) != 0)
                    RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket");
            }

            // let the installed handler know a client has disconnected
            if (details.disconnectedCb)
                details.disconnectedCb(details.client);

            // ensure client object is destructed without the client lock held
            details.client.reset();

            // re-take the lock for the next client
            locker.lock();
        }
    }

    return true;
}
494 :
495 : // -----------------------------------------------------------------------------
496 : /*!
497 : \internal
498 :
499 : Adds the \a socketFd to the internal list of client sockets. This is called
500 : when a new connection is accepted on a listening socket, or when a client
501 : fd is added via ServerImpl::addClient(...).
502 :
503 : Returns a nullptr if failed to add the socket.
504 :
505 : */
506 23 : std::shared_ptr<ClientImpl> ServerImpl::addClientSocket(int socketFd, const std::string &listeningSocketPath,
507 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb)
508 : {
509 : // get the client credentials
510 23 : struct ucred clientCreds = {0};
511 23 : socklen_t clientCredsLen = sizeof(clientCreds);
512 :
513 46 : if ((getsockopt(socketFd, SOL_SOCKET, SO_PEERCRED, &clientCreds, &clientCredsLen) < 0) ||
514 23 : (clientCredsLen != sizeof(clientCreds)))
515 : {
516 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get client's details");
517 0 : return nullptr;
518 : }
519 :
520 : // create a new unique client id for the connection
521 23 : const uint64_t kClientId = m_clientIdCounter++;
522 :
523 : // add the new socket to the poll loop
524 23 : epoll_event event = {.events = EPOLLIN, .data = {.u64 = kClientId}};
525 23 : if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socketFd, &event) != 0)
526 : {
527 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add client socket");
528 0 : return nullptr;
529 : }
530 :
531 : // create initial client object for the socket
532 23 : auto client = std::make_shared<ClientImpl>(shared_from_this(), kClientId, clientCreds);
533 :
534 : // and the details for the internal list
535 23 : ClientDetails clientDetails;
536 23 : clientDetails.sock = socketFd;
537 23 : clientDetails.disconnectedCb = std::move(disconnectedCb);
538 23 : clientDetails.client = client;
539 :
540 : // add to the set of clients
541 : {
542 23 : std::lock_guard<std::mutex> locker(m_clientsLock);
543 23 : m_clients.emplace(kClientId, clientDetails);
544 : }
545 :
546 23 : RIALTO_IPC_LOG_INFO("new client connected - giving id %" PRIu64, kClientId);
547 :
548 23 : return client;
549 : }
550 :
551 : // -----------------------------------------------------------------------------
552 : /*!
553 : \internal
554 :
555 : Called when an event occurs on the listening socket. The code first accepts
556 : the connection and retrieves the client details, it then calls the installed
557 : handler to determine if we should drop this connection or not.
558 :
559 :
560 : */
561 9 : void ServerImpl::processNewConnection(uint64_t socketId)
562 : {
563 9 : RIALTO_IPC_LOG_DEBUG("processing new connection");
564 :
565 9 : std::unique_lock<std::mutex> socketLocker(m_socketsLock);
566 :
567 : // find matching socket object
568 9 : auto it = m_sockets.find(socketId);
569 9 : if (it == m_sockets.end())
570 : {
571 0 : RIALTO_IPC_LOG_ERROR("failed to find listening socket with id %" PRIu64, socketId);
572 0 : return;
573 : }
574 :
575 9 : const Socket &kSocket = it->second;
576 :
577 : // accept the connection from the client
578 9 : struct sockaddr clientAddr = {0};
579 9 : socklen_t clientAddrLen = sizeof(clientAddr);
580 :
581 9 : int clientSock = accept4(kSocket.sockFd, &clientAddr, &clientAddrLen, SOCK_NONBLOCK | SOCK_CLOEXEC);
582 9 : if (clientSock < 0)
583 : {
584 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to accept client connection");
585 0 : return;
586 : }
587 :
588 9 : const std::string kSockPath = kSocket.sockPath;
589 9 : std::function<void(const std::shared_ptr<IClient> &)> connectedCb = kSocket.connectedCb;
590 9 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb = kSocket.disconnectedCb;
591 :
592 9 : socketLocker.unlock();
593 :
594 : // attempt to add the socket to the client list
595 9 : auto client = addClientSocket(clientSock, kSockPath, std::move(disconnectedCb));
596 9 : if (!client)
597 : {
598 0 : close(clientSock);
599 0 : return;
600 : }
601 :
602 : // notify the handler that a new connection has been made
603 9 : if (connectedCb)
604 : {
605 : // tell the handler we have a new client
606 9 : connectedCb(client);
607 : }
608 : }
609 :
610 : // -----------------------------------------------------------------------------
611 : /*!
612 : \internal
613 : \static
614 :
615 : Reads all the file descriptors from a message. It returns the file descriptors
616 : as a vector of FileDescriptor objects, these objects safely store the fd and
617 : close them when they're destructed.
618 :
619 : */
620 0 : static std::vector<FileDescriptor> readMessageFds(const struct msghdr *msg, size_t limit)
621 : {
622 0 : std::vector<FileDescriptor> fds;
623 :
624 0 : for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg); cmsg != nullptr; cmsg = CMSG_NXTHDR((struct msghdr *)msg, cmsg))
625 : {
626 0 : if ((cmsg->cmsg_level == SOL_SOCKET) && (cmsg->cmsg_type == SCM_RIGHTS))
627 : {
628 0 : const unsigned kFdsLength = cmsg->cmsg_len - CMSG_LEN(0);
629 0 : if ((kFdsLength < sizeof(int)) || ((kFdsLength % sizeof(int)) != 0))
630 : {
631 0 : RIALTO_IPC_LOG_ERROR("invalid fd array size");
632 : }
633 : else
634 : {
635 0 : const size_t n = kFdsLength / sizeof(int);
636 0 : RIALTO_IPC_LOG_DEBUG("received %zu fds", n);
637 :
638 0 : fds.reserve(std::min(limit, n));
639 :
640 0 : const int *kFds = reinterpret_cast<int *>(CMSG_DATA(cmsg));
641 0 : for (size_t i = 0; i < n; i++)
642 : {
643 0 : RIALTO_IPC_LOG_DEBUG("received fd %d", kFds[i]);
644 :
645 0 : if (fds.size() >= limit)
646 : {
647 0 : RIALTO_IPC_LOG_ERROR("received to many file descriptors, "
648 : "exceeding max per message, closing left overs");
649 : }
650 : else
651 : {
652 0 : FileDescriptor fd(kFds[i]);
653 0 : if (!fd.isValid())
654 : {
655 0 : RIALTO_IPC_LOG_ERROR("received invalid fd (couldn't dup)");
656 : }
657 : else
658 : {
659 0 : fds.emplace_back(std::move(fd));
660 : }
661 : }
662 :
663 0 : if (close(kFds[i]) != 0)
664 : {
665 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close received fd");
666 : }
667 : }
668 : }
669 : }
670 : }
671 :
672 0 : return fds;
673 : }
674 :
// -----------------------------------------------------------------------------
/*!
    \internal

    Processes an epoll event (mask \a events) from the socket of client
    \a clientId.

    The event is ignored if the client is unknown or already condemned. On
    EPOLLERR the client is disconnected. On EPOLLIN every queued message is
    read until recvmsg() would block or the peer has closed the connection.
    Each complete message is handed to processClientMessage(), together with
    any file descriptors passed as control data. Truncated messages are
    discarded after their passed fds are released.

    The receive buffers (m_recvDataBuf / m_recvCtrlBuf) are members that are
    reused for every read.
*/
void ServerImpl::processClientSocket(uint64_t clientId, unsigned events)
{
    // take the lock while accessing the client list
    std::unique_lock<std::mutex> locker(m_clientsLock);

    auto it = m_clients.find(clientId);
    if (it == m_clients.end())
    {
        // should never happen
        RIALTO_IPC_LOG_ERROR("received an event from a socket with no matching client");
        return;
    }

    // check if the client is marked for closure, if so then just ignore the data
    if (m_condemnedClients.count(clientId) != 0)
    {
        return;
    }

    // get the socket that corresponds to the client connection
    const int kSockFd = it->second.sock;

    // get the client object
    std::shared_ptr<ClientImpl> clientObj = it->second.client;

    // can safely release the lock now we have the clientId and client object
    locker.unlock();

    // if there was an error disconnect the socket
    if (events & EPOLLERR)
    {
        RIALTO_IPC_LOG_ERROR("error detected on client socket - disconnecting client");
        disconnectClient(clientId);
        return;
    }

    if (events & EPOLLIN)
    {
        // read all messages from the client socket, we break out if the socket is closed
        // or EWOULDBLOCK is returned on a read (ie. no more messages to read)
        while (true)
        {
            struct msghdr msg = {nullptr};
            struct iovec io = {.iov_base = m_recvDataBuf, .iov_len = sizeof(m_recvDataBuf)};

            bzero(&msg, sizeof(msg));
            msg.msg_iov = &io;
            msg.msg_iovlen = 1;
            msg.msg_control = m_recvCtrlBuf;
            msg.msg_controllen = sizeof(m_recvCtrlBuf);

            // read one message (MSG_CMSG_CLOEXEC sets close-on-exec on any received fds)
            ssize_t rd = TEMP_FAILURE_RETRY(recvmsg(kSockFd, &msg, MSG_CMSG_CLOEXEC));
            if (rd < 0)
            {
                // EWOULDBLOCK just means the queue is drained (socket is non-blocking)
                if (errno != EWOULDBLOCK)
                {
                    RIALTO_IPC_LOG_SYS_ERROR(errno, "error reading client socket");
                    disconnectClient(clientId);
                }

                break;
            }
            else if (rd == 0)
            {
                // client closed connection, and we've read all data, add to the condemned set
                // so is cleaned up once all the events are processed
                disconnectClient(clientId);

                break;
            }
            else if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))
            {
                RIALTO_IPC_LOG_WARN("received message from client %" PRIu64 " truncated, discarding", clientId);

                // make sure to close all the fds, otherwise we'll leak them
                // (the returned vector is discarded immediately)
                readMessageFds(&msg, 16);
            }
            else
            {
                // if there is control data then assume fd(s) have been passed
                if (msg.msg_controllen > 0)
                {
                    processClientMessage(clientObj, m_recvDataBuf, rd, readMessageFds(&msg, 16));
                }
                else
                {
                    processClientMessage(clientObj, m_recvDataBuf, rd);
                }
            }
        }
    }
}
775 :
776 : // -----------------------------------------------------------------------------
777 : /*!
778 : \internal
779 :
780 : Places the received file descriptors into the request message.
781 :
782 : It works by iterating over the fields in the message, finding ones that are
783 : marked as 'field_is_fd' and then replacing the received integer value with
784 : an actual file descriptor.
785 :
786 : */
787 15 : static bool addRequestFileDescriptors(google::protobuf::Message *request, const std::vector<FileDescriptor> &requestFds)
788 : {
789 15 : auto fdIterator = requestFds.begin();
790 :
791 15 : const google::protobuf::Descriptor *kDescriptor = request->GetDescriptor();
792 15 : const google::protobuf::Reflection *kReflection = nullptr;
793 :
794 15 : const int n = kDescriptor->field_count();
795 47 : for (int i = 0; i < n; i++)
796 : {
797 32 : auto fieldDescriptor = kDescriptor->field(i);
798 32 : if (fieldDescriptor->options().HasExtension(field_is_fd) && fieldDescriptor->options().GetExtension(field_is_fd))
799 : {
800 0 : if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
801 : {
802 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
803 0 : return false;
804 : }
805 :
806 0 : if (!kReflection)
807 : {
808 0 : kReflection = request->GetReflection();
809 : }
810 :
811 0 : if (kReflection->HasField(*request, fieldDescriptor))
812 : {
813 0 : if (fdIterator == requestFds.end())
814 : {
815 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but one was supplied");
816 0 : return false;
817 : }
818 :
819 0 : kReflection->SetInt32(request, fieldDescriptor, fdIterator->fd());
820 0 : ++fdIterator;
821 : }
822 : }
823 : }
824 :
825 15 : if (fdIterator != requestFds.end())
826 : {
827 0 : RIALTO_IPC_LOG_ERROR("received too many file descriptors in the message");
828 0 : return false;
829 : }
830 :
831 15 : return true;
832 : }
833 :
834 : // -----------------------------------------------------------------------------
835 : /*!
836 : \internal
837 :
838 : Processes a message received on a client socket.
839 :
840 : */
841 15 : void ServerImpl::processClientMessage(const std::shared_ptr<ClientImpl> &client, const uint8_t *data, size_t dataLen,
842 : const std::vector<FileDescriptor> &fds)
843 : {
844 15 : RIALTO_IPC_LOG_DEBUG("processing client message of size %zu bytes (%zu fds) from client %" PRId64, dataLen,
845 : fds.size(), client->id());
846 :
847 : // parse the message
848 15 : transport::MessageToServer message;
849 15 : if (!message.ParseFromArray(data, static_cast<int>(dataLen)))
850 : {
851 0 : RIALTO_IPC_LOG_ERROR("invalid request");
852 0 : return;
853 : }
854 :
855 15 : if (message.has_call())
856 : {
857 15 : processMethodCall(client, message.call(), fds);
858 : }
859 : else
860 : {
861 0 : RIALTO_IPC_LOG_WARN("received unknown message type from client");
862 : }
863 15 : }
864 :
865 : // -----------------------------------------------------------------------------
866 : /*!
867 : \internal
868 :
869 : Processes a method call requst from a client.
870 :
871 : */
872 15 : void ServerImpl::processMethodCall(const std::shared_ptr<ClientImpl> &client, const transport::MethodCall &call,
873 : const std::vector<FileDescriptor> &fds)
874 : {
875 : // try and find the service with the given name
876 15 : const std::string &kServiceName = call.service_name();
877 15 : auto it = client->m_services.find(kServiceName);
878 15 : if (it == client->m_services.end())
879 : {
880 0 : RIALTO_IPC_LOG_ERROR("unknown service request '%s'", kServiceName.c_str());
881 :
882 0 : sendErrorReply(client, call.serial_id(), "Unknown service '%s'", kServiceName.c_str());
883 0 : return;
884 : }
885 :
886 15 : std::shared_ptr<google::protobuf::Service> service = it->second;
887 :
888 : // try and find the method
889 15 : const std::string &kMethodName = call.method_name();
890 15 : const google::protobuf::MethodDescriptor *kMethod = service->GetDescriptor()->FindMethodByName(kMethodName);
891 15 : if (!kMethod)
892 : {
893 0 : RIALTO_IPC_LOG_ERROR("no method with name '%s'", kMethodName.c_str());
894 :
895 0 : sendErrorReply(client, call.serial_id(), "Unknown method '%s'", kMethodName.c_str());
896 0 : return;
897 : }
898 :
899 : // check if the method is expecting a reply
900 15 : const bool kNoReply = kMethod->options().HasExtension(no_reply) && kMethod->options().GetExtension(no_reply);
901 :
902 : // parse the request data
903 15 : google::protobuf::Message *requestMessage = service->GetRequestPrototype(kMethod).New();
904 15 : if (!requestMessage->ParseFromString(call.request_message()))
905 : {
906 0 : RIALTO_IPC_LOG_ERROR("failed to parse method from array");
907 : }
908 15 : else if (!addRequestFileDescriptors(requestMessage, fds))
909 : {
910 0 : RIALTO_IPC_LOG_ERROR("mismatch of file descriptors to the request");
911 : }
912 : else
913 : {
914 15 : RIALTO_IPC_LOG_DEBUG("call{ serial %" PRIu64 " } - %s.%s { %s }", call.serial_id(), kServiceName.c_str(),
915 : kMethodName.c_str(), requestMessage->ShortDebugString().c_str());
916 :
917 15 : auto *controller = new ServerControllerImpl(client, call.serial_id());
918 :
919 15 : if (kNoReply)
920 : {
921 : // we should not send a reply for this call, so call the code to handle the
922 : // request, but no need to pass a controller, response or closure object
923 1 : static google::protobuf::internal::FunctionClosure0 nullClosure(&google::protobuf::DoNothing, false);
924 1 : service->CallMethod(kMethod, controller, requestMessage, nullptr, &nullClosure);
925 :
926 1 : delete controller;
927 : }
928 : else
929 : {
930 : // create a response
931 14 : google::protobuf::Message *responseMessage = service->GetResponsePrototype(kMethod).New();
932 :
933 : // this is finally where we call the service implementation to process the request
934 14 : service->CallMethod(kMethod, controller, requestMessage, responseMessage,
935 : google::protobuf::NewCallback(this, &ServerImpl::handleResponse, controller,
936 : responseMessage));
937 : }
938 : }
939 :
940 15 : delete requestMessage;
941 : }
942 :
943 : // -----------------------------------------------------------------------------
944 : /*!
945 : \internal
946 : \threadsafe
947 :
948 : Sends a message to the given client if still connected.
949 :
950 : */
951 14 : void ServerImpl::sendReply(uint64_t clientId, const std::shared_ptr<msghdr> &msg)
952 : {
953 : // now take the lock (so the socket is not closed beneath us) and send the reply
954 14 : std::lock_guard<std::mutex> locker(m_clientsLock);
955 :
956 14 : auto it = m_clients.find(clientId);
957 14 : if (it == m_clients.end())
958 : {
959 1 : RIALTO_IPC_LOG_WARN("socket removed before error reply could be sent");
960 : }
961 13 : else if (it->second.sock < 0)
962 : {
963 0 : RIALTO_IPC_LOG_WARN("socket closed before error reply could be sent");
964 : }
965 13 : else if (!msg)
966 : {
967 0 : RIALTO_IPC_LOG_WARN("invalid msg to send on socket, ignoring");
968 : }
969 13 : else if (TEMP_FAILURE_RETRY(sendmsg(it->second.sock, msg.get(), MSG_NOSIGNAL)) !=
970 13 : static_cast<ssize_t>(msg->msg_iov->iov_len))
971 : {
972 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to send the complete error reply message");
973 : }
974 14 : }
975 :
976 : // -----------------------------------------------------------------------------
977 : /*!
978 : \internal
979 :
980 : Sends an error back to the client with a reason string in \a format.
981 :
982 : */
983 0 : void ServerImpl::sendErrorReply(const std::shared_ptr<ClientImpl> &client, uint64_t serialId, const char *format, ...)
984 : {
985 : // construct the error string
986 : va_list ap;
987 0 : va_start(ap, format);
988 : char reason[512];
989 0 : vsnprintf(reason, sizeof(reason), format, ap);
990 0 : va_end(ap);
991 :
992 : // construct the reply message
993 0 : auto msg = populateErrorReply(client, serialId, reason);
994 :
995 : // and send it
996 0 : sendReply(client->id(), msg);
997 : }
998 :
999 : // -----------------------------------------------------------------------------
1000 : /*!
1001 : \internal
1002 : \threadsafe
1003 :
1004 : Called once the service handler code has processed the request and
1005 : completed it. This may be called from a different thread if the service
1006 : implementation decided to off-load the request to a processing thread.
1007 :
1008 : */
1009 14 : void ServerImpl::handleResponse(ServerControllerImpl *controller, google::protobuf::Message *response)
1010 : {
1011 14 : if (!controller || !controller->m_kClient)
1012 : {
1013 0 : RIALTO_IPC_LOG_ERROR("missing controller or attached client");
1014 0 : return;
1015 : }
1016 :
1017 14 : const std::shared_ptr<const ClientImpl> kClient = controller->m_kClient;
1018 14 : const uint64_t kClientId = kClient->id();
1019 :
1020 14 : std::shared_ptr<msghdr> message;
1021 14 : if (!controller->m_failed)
1022 : {
1023 8 : message = populateReply(kClient, controller->m_kSerialId, response);
1024 : }
1025 : else
1026 : {
1027 6 : message = populateErrorReply(kClient, controller->m_kSerialId, controller->m_failureReason);
1028 : }
1029 :
1030 : // no longer need the controller or the response objects
1031 14 : delete response;
1032 14 : delete controller;
1033 :
1034 : // send the reply message to the given client
1035 14 : sendReply(kClientId, message);
1036 : }
1037 :
1038 : // -----------------------------------------------------------------------------
1039 : /*!
1040 : \internal
1041 :
1042 : Gets all the file descriptors that are stored in the \a response message
1043 : and returns them as a vector of ints.
1044 :
1045 : It works by iterating over the fields in the message, finding ones that are
1046 : marked as 'field_is_fd' and inserting the fd value into control part of the
1047 : message. The actual integer sent in the data part of the message is set to
1048 : -1.
1049 :
1050 : */
1051 12 : static std::vector<int> getResponseFileDescriptors(google::protobuf::Message *response)
1052 : {
1053 12 : std::vector<int> fds;
1054 :
1055 : // process any file descriptors from the response message
1056 12 : const google::protobuf::Descriptor *kDescriptor = response->GetDescriptor();
1057 12 : const google::protobuf::Reflection *kReflection = nullptr;
1058 :
1059 12 : const int n = kDescriptor->field_count();
1060 25 : for (int i = 0; i < n; i++)
1061 : {
1062 13 : auto fieldDescriptor = kDescriptor->field(i);
1063 13 : if (fieldDescriptor->options().HasExtension(field_is_fd) && fieldDescriptor->options().GetExtension(field_is_fd))
1064 : {
1065 0 : if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
1066 : {
1067 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
1068 0 : return {};
1069 : }
1070 :
1071 0 : if (!kReflection)
1072 : {
1073 0 : kReflection = response->GetReflection();
1074 : }
1075 :
1076 0 : if (kReflection->HasField(*response, fieldDescriptor))
1077 : {
1078 0 : fds.push_back(kReflection->GetInt32(*response, fieldDescriptor));
1079 0 : kReflection->SetInt32(response, fieldDescriptor, -1);
1080 : }
1081 : }
1082 : }
1083 :
1084 12 : return fds;
1085 : }
1086 :
1087 : // -----------------------------------------------------------------------------
1088 : /*!
1089 : \internal
1090 : \static
1091 :
1092 : Populates tha socket message buffer with the reply data for the RPC request.
1093 :
1094 : */
1095 8 : std::shared_ptr<msghdr> ServerImpl::populateReply(const std::shared_ptr<const ClientImpl> &client, uint64_t serialId,
1096 : google::protobuf::Message *response)
1097 : {
1098 : // create the base reply
1099 8 : transport::MessageFromServer message;
1100 8 : transport::MethodCallReply *reply = message.mutable_reply();
1101 8 : if (!reply)
1102 : {
1103 0 : RIALTO_IPC_LOG_ERROR("failed to create mutable reply object");
1104 0 : return nullptr;
1105 : }
1106 :
1107 : // convert the message response to a data string
1108 8 : std::string respString = response->SerializeAsString();
1109 :
1110 : // wrap in a transport response and send that
1111 8 : reply->set_reply_id(serialId);
1112 8 : reply->set_reply_message(std::move(respString));
1113 :
1114 : // next need to check if the response message has any file descriptors in
1115 : // it that need to be attached
1116 8 : const std::vector<int> kFds = getResponseFileDescriptors(response);
1117 8 : const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
1118 :
1119 : // calculate the size of the reply
1120 8 : const size_t kRequiredDataLen = message.ByteSizeLong();
1121 8 : if (kRequiredDataLen > m_kMaxMessageLen)
1122 : {
1123 0 : RIALTO_IPC_LOG_ERROR("reply exceeds maximum message limit (%zu, max %zu)", kRequiredDataLen, m_kMaxMessageLen);
1124 :
1125 : // error message is too big, replace with a generic error
1126 0 : return populateErrorReply(client, serialId, "Internal error - reply message to large");
1127 : }
1128 :
1129 : // build the socket message to send
1130 : auto msgBuf =
1131 8 : m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + kRequiredDataLen);
1132 :
1133 8 : auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
1134 8 : bzero(header, sizeof(msghdr));
1135 :
1136 8 : auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
1137 8 : header->msg_control = ctrl;
1138 8 : header->msg_controllen = kRequiredCtrlLen;
1139 :
1140 8 : auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
1141 8 : header->msg_iov = iov;
1142 8 : header->msg_iovlen = 1;
1143 :
1144 8 : auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
1145 8 : iov->iov_base = data;
1146 8 : iov->iov_len = kRequiredDataLen;
1147 :
1148 : // copy in the data
1149 8 : message.SerializeWithCachedSizesToArray(data);
1150 :
1151 : // add the fds
1152 8 : if (!kFds.empty())
1153 : {
1154 0 : struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
1155 0 : if (!cmsg)
1156 : {
1157 0 : RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
1158 0 : return nullptr;
1159 : }
1160 :
1161 0 : cmsg->cmsg_level = SOL_SOCKET;
1162 0 : cmsg->cmsg_type = SCM_RIGHTS;
1163 0 : cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
1164 0 : memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
1165 0 : header->msg_controllen = cmsg->cmsg_len;
1166 : }
1167 :
1168 8 : RIALTO_IPC_LOG_DEBUG("reply{ serial %" PRIu64 " } - { %s }", serialId, response->ShortDebugString().c_str());
1169 :
1170 : // std::reinterpret_pointer_cast is only implemented in C++17 and newer, so for
1171 : // now do it manually
1172 8 : return std::shared_ptr<msghdr>(msgBuf, reinterpret_cast<msghdr *>(msgBuf.get()));
1173 : }
1174 :
1175 : // -----------------------------------------------------------------------------
1176 : /*!
1177 : \internal
1178 :
1179 : Populates tha socket message buffer with a reply error message.
1180 :
1181 : */
1182 6 : std::shared_ptr<msghdr> ServerImpl::populateErrorReply(const std::shared_ptr<const ClientImpl> &client,
1183 : uint64_t serialId, const std::string &reason)
1184 : {
1185 : // create the base reply
1186 6 : transport::MessageFromServer message;
1187 6 : transport::MethodCallError *error = message.mutable_error();
1188 6 : error->set_reply_id(serialId);
1189 6 : error->set_error_reason(reason);
1190 :
1191 : // check the message will fit
1192 6 : size_t replySize = message.ByteSizeLong();
1193 6 : if (replySize > m_kMaxMessageLen)
1194 : {
1195 0 : RIALTO_IPC_LOG_ERROR("error reply exceeds max message size");
1196 :
1197 : // error message is to big, replace with a generic error
1198 0 : error->set_error_reason("Error message truncated");
1199 0 : replySize = message.ByteSizeLong();
1200 : }
1201 :
1202 6 : RIALTO_IPC_LOG_DEBUG("error{ serial %" PRIu64 " } - \"%s\"", serialId, reason.c_str());
1203 :
1204 : // construct the message to send on the socket
1205 6 : auto msgBuf = m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + replySize);
1206 :
1207 6 : auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
1208 6 : bzero(header, sizeof(msghdr));
1209 :
1210 6 : auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr));
1211 6 : header->msg_iov = iov;
1212 6 : header->msg_iovlen = 1;
1213 :
1214 6 : auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + sizeof(iovec));
1215 6 : iov->iov_base = data;
1216 6 : iov->iov_len = replySize;
1217 :
1218 : // serialise the reply to the message buffer
1219 6 : message.SerializeWithCachedSizesToArray(data);
1220 :
1221 : // std::reinterpret_pointer_cast is only implemented in C++17 and newer, so for
1222 : // now do it manually
1223 12 : return std::shared_ptr<msghdr>(msgBuf, reinterpret_cast<msghdr *>(msgBuf.get()));
1224 6 : }
1225 :
1226 : // -----------------------------------------------------------------------------
1227 : /*!
1228 : \threadsafe
1229 :
1230 : Returns \c true if the client with \a clientId is currently connected to
1231 : the server.
1232 :
1233 : */
1234 53 : bool ServerImpl::isClientConnected(uint64_t clientId) const
1235 : {
1236 53 : std::unique_lock<std::mutex> locker(m_clientsLock);
1237 106 : return (m_clients.count(clientId) > 0);
1238 53 : }
1239 :
1240 : // -----------------------------------------------------------------------------
1241 : /*!
1242 : \threadsafe
1243 :
1244 : May be called internally when there is an error on the socket, or externally
1245 : (possibly from a different thread) if the handler or service code decides to
1246 : close the client connection.
1247 :
1248 : */
1249 23 : void ServerImpl::disconnectClient(uint64_t clientId)
1250 : {
1251 23 : std::unique_lock<std::mutex> locker(m_clientsLock);
1252 23 : m_condemnedClients.insert(clientId);
1253 23 : locker.unlock();
1254 :
1255 23 : wakeEventLoop();
1256 : }
1257 :
1258 : // -----------------------------------------------------------------------------
1259 : /*!
1260 : \threadsafe
1261 :
1262 : Called via the IAVBusClient interface when an async event should be sent.
1263 :
1264 : This may be called from any thread, or from within the rpc message handler.
1265 :
1266 : The \a clientId is the client to send the event to.
1267 :
1268 : */
1269 4 : bool ServerImpl::sendEvent(uint64_t clientId, const std::shared_ptr<google::protobuf::Message> &eventMessage)
1270 : {
1271 : // gets the file descriptors from the event message
1272 4 : const std::vector<int> kFds = getResponseFileDescriptors(eventMessage.get());
1273 4 : const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
1274 :
1275 : // create the base reply
1276 4 : transport::MessageFromServer message;
1277 4 : transport::EventFromServer *event = message.mutable_event();
1278 4 : if (!event)
1279 : {
1280 0 : RIALTO_IPC_LOG_ERROR("failed to create mutable event object");
1281 0 : return false;
1282 : }
1283 :
1284 4 : event->set_event_name(eventMessage->GetTypeName());
1285 :
1286 : // convert the event to a data string
1287 4 : std::string respString = eventMessage->SerializeAsString();
1288 :
1289 : // wrap in a transport response and send that
1290 4 : event->set_message(std::move(respString));
1291 :
1292 : // check the reply will fit
1293 4 : size_t requiredDataLen = message.ByteSizeLong();
1294 4 : if (requiredDataLen > m_kMaxMessageLen)
1295 : {
1296 0 : RIALTO_IPC_LOG_ERROR("event message to big to fit in buffer (size %zu, max size %zu)", requiredDataLen,
1297 : m_kMaxMessageLen);
1298 0 : return false;
1299 : }
1300 :
1301 : // build the socket message to send
1302 : auto msgBuf =
1303 4 : m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + requiredDataLen);
1304 :
1305 4 : auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
1306 4 : bzero(header, sizeof(msghdr));
1307 :
1308 4 : auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
1309 4 : header->msg_control = ctrl;
1310 4 : header->msg_controllen = kRequiredCtrlLen;
1311 :
1312 4 : auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
1313 4 : header->msg_iov = iov;
1314 4 : header->msg_iovlen = 1;
1315 :
1316 4 : auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
1317 4 : iov->iov_base = data;
1318 4 : iov->iov_len = requiredDataLen;
1319 :
1320 : // copy in the data
1321 4 : message.SerializeWithCachedSizesToArray(data);
1322 :
1323 : // add the fds
1324 4 : if (!kFds.empty())
1325 : {
1326 0 : struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
1327 0 : if (!cmsg)
1328 : {
1329 0 : RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
1330 0 : return false;
1331 : }
1332 :
1333 0 : cmsg->cmsg_level = SOL_SOCKET;
1334 0 : cmsg->cmsg_type = SCM_RIGHTS;
1335 0 : cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
1336 0 : memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
1337 0 : header->msg_controllen = cmsg->cmsg_len;
1338 : }
1339 :
1340 : // finally, take the lock (so the socket is not closed beneath us) and send the reply
1341 4 : std::unique_lock<std::mutex> locker(m_clientsLock);
1342 :
1343 4 : auto it = m_clients.find(clientId);
1344 4 : if (it == m_clients.end() || it->second.sock < 0)
1345 : {
1346 0 : RIALTO_IPC_LOG_WARN("socket closed before event could be sent");
1347 0 : return false;
1348 : }
1349 4 : else if (TEMP_FAILURE_RETRY(sendmsg(it->second.sock, header, MSG_NOSIGNAL)) != static_cast<ssize_t>(requiredDataLen))
1350 : {
1351 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to send the complete event message");
1352 0 : return false;
1353 : }
1354 :
1355 4 : locker.unlock();
1356 :
1357 4 : RIALTO_IPC_LOG_DEBUG("event{ %s } - { %s }", eventMessage->GetTypeName().c_str(),
1358 : eventMessage->ShortDebugString().c_str());
1359 :
1360 4 : return true;
1361 : }
1362 : } // namespace firebolt::rialto::ipc
|