Line data Source code
1 : /*
2 : * If not stated otherwise in this file or this component's LICENSE file the
3 : * following copyright and licenses apply:
4 : *
5 : * Copyright 2022 Sky UK
6 : *
7 : * Licensed under the Apache License, Version 2.0 (the "License");
8 : * you may not use this file except in compliance with the License.
9 : * You may obtain a copy of the License at
10 : *
11 : * http://www.apache.org/licenses/LICENSE-2.0
12 : *
13 : * Unless required by applicable law or agreed to in writing, software
14 : * distributed under the License is distributed on an "AS IS" BASIS,
15 : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 : * See the License for the specific language governing permissions and
17 : * limitations under the License.
18 : */
19 :
20 : #include "IpcServerImpl.h"
21 : #include "IIpcServerFactory.h"
22 : #include "IpcClientImpl.h"
23 : #include "IpcLogging.h"
24 : #include "IpcServerControllerImpl.h"
25 :
26 : #include "rialtoipc.pb.h"
27 :
28 : #include <google/protobuf/service.h>
29 :
30 : #include <algorithm>
31 : #include <cinttypes>
32 : #include <cstdarg>
33 : #include <string>
34 : #include <utility>
35 : #include <vector>
36 :
37 : #include <fcntl.h>
38 : #include <poll.h>
39 : #include <sys/epoll.h>
40 : #include <sys/eventfd.h>
41 : #include <sys/file.h>
42 : #include <sys/socket.h>
43 : #include <sys/stat.h>
44 : #include <sys/un.h>
45 : #include <unistd.h>
46 :
// Identifiers stored in epoll_event::data.u64 so that process() can tell which
// kind of fd an event belongs to:
//   0                         - the internal wake-up eventfd
//   1 .. FIRST_CLIENT_ID - 1  - listening sockets (addSocket() rejects ids >= FIRST_CLIENT_ID)
//   FIRST_CLIENT_ID and up    - connected client sockets
#define WAKE_EVENT_ID (static_cast<uint64_t>(0))
#define FIRST_LISTENING_SOCKET_ID (static_cast<uint64_t>(1))
#define FIRST_CLIENT_ID (static_cast<uint64_t>(10000))
50 :
51 : namespace firebolt::rialto::ipc
52 : {
53 : const size_t ServerImpl::m_kMaxMessageLen = (128 * 1024);
54 :
55 32 : std::shared_ptr<IServerFactory> IServerFactory::createFactory()
56 : {
57 32 : std::shared_ptr<IServerFactory> factory;
58 : try
59 : {
60 32 : factory = std::make_shared<ServerFactory>();
61 : }
62 0 : catch (const std::exception &e)
63 : {
64 0 : RIALTO_IPC_LOG_ERROR("Failed to create the server factory, reason: %s", e.what());
65 : }
66 :
67 32 : return factory;
68 : }
69 :
70 32 : std::shared_ptr<IServer> ServerFactory::create()
71 : {
72 32 : return std::make_shared<ServerImpl>();
73 : }
74 :
75 32 : ServerImpl::ServerImpl()
76 32 : : m_pollFd(-1), m_wakeEventFd(-1), m_socketIdCounter(FIRST_LISTENING_SOCKET_ID), m_clientIdCounter(FIRST_CLIENT_ID),
77 64 : m_recvDataBuf{0}, m_recvCtrlBuf{0}
78 : {
79 : // create the eventfd use to wake the poll loop
80 32 : m_wakeEventFd = eventfd(0, EFD_CLOEXEC);
81 32 : if (m_wakeEventFd < 0)
82 : {
83 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "eventfd failed");
84 0 : return;
85 : }
86 :
87 : // create epoll loop
88 32 : m_pollFd = epoll_create1(EPOLL_CLOEXEC);
89 32 : if (m_pollFd < 0)
90 : {
91 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_create1 failed");
92 0 : return;
93 : }
94 :
95 : // add the wake event to epoll
96 32 : epoll_event event = {.events = EPOLLIN, .data = {.u64 = WAKE_EVENT_ID}};
97 32 : if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, m_wakeEventFd, &event) != 0)
98 : {
99 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
100 : }
101 : }
102 :
103 32 : ServerImpl::~ServerImpl()
104 : {
105 32 : if ((m_pollFd >= 0) && (close(m_pollFd) != 0))
106 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close epoll");
107 :
108 32 : if ((m_wakeEventFd >= 0) && (close(m_wakeEventFd) != 0))
109 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close eventfd");
110 :
111 45 : for (const auto &entry : m_sockets)
112 : {
113 13 : const Socket &kSocket = entry.second;
114 :
115 13 : if (kSocket.isOwned)
116 : {
117 9 : if (unlink(kSocket.sockPath.c_str()) != 0)
118 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket @ '%s'", kSocket.sockPath.c_str());
119 9 : if (close(kSocket.sockFd) != 0)
120 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close listening socket");
121 :
122 9 : if (unlink(kSocket.lockPath.c_str()) != 0)
123 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket lock file @ '%s'", kSocket.lockPath.c_str());
124 9 : if (close(kSocket.lockFd) != 0)
125 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket lock file");
126 : }
127 : }
128 32 : }
129 :
130 : // -----------------------------------------------------------------------------
131 : /*!
132 : \static
133 : \internal
134 :
135 : Creates (if required) and takes the file lock associated with the socket
136 : path in the \a socket object.
137 :
138 : */
139 9 : bool ServerImpl::getSocketLock(Socket *socket)
140 : {
141 9 : std::string lockPath = socket->sockPath + ".lock";
142 9 : int fileDescriptor = open(lockPath.c_str(), O_CREAT | O_CLOEXEC, (S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP));
143 9 : if (fileDescriptor < 0)
144 : {
145 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to create / open lockfile @ '%s' (check permissions)", lockPath.c_str());
146 0 : return false;
147 : }
148 :
149 9 : if (flock(fileDescriptor, LOCK_EX | LOCK_NB) < 0)
150 : {
151 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to lock lockfile @ '%s', maybe another server is running",
152 : lockPath.c_str());
153 0 : close(fileDescriptor);
154 0 : return false;
155 : }
156 :
157 9 : struct stat sbuf = {0};
158 9 : if (stat(socket->sockPath.c_str(), &sbuf) < 0)
159 : {
160 9 : if (errno != ENOENT)
161 : {
162 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "did not manage to stat existing socket @ '%s'", socket->sockPath.c_str());
163 0 : close(fileDescriptor);
164 0 : return false;
165 : }
166 : }
167 0 : else if ((sbuf.st_mode & S_IWUSR) || (sbuf.st_mode & S_IWGRP))
168 : {
169 0 : unlink(socket->sockPath.c_str());
170 : }
171 :
172 9 : socket->lockFd = fileDescriptor;
173 9 : socket->lockPath = std::move(lockPath);
174 :
175 9 : return true;
176 : }
177 :
178 : // -----------------------------------------------------------------------------
179 : /*!
180 : \static
181 : \internal
182 :
183 : Closes the file descriptors and the deletes the files stored in the \a socket
184 : object.
185 :
186 : */
187 0 : void ServerImpl::closeListeningSocket(Socket *socket)
188 : {
189 0 : if (!socket->isOwned)
190 : {
191 0 : return;
192 : }
193 :
194 0 : if (!socket->sockPath.empty() && (unlink(socket->sockPath.c_str()) != 0) && (errno != ENOENT))
195 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket @ '%s'", socket->sockPath.c_str());
196 0 : if ((socket->sockFd >= 0) && (close(socket->sockFd) != 0))
197 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close listening socket");
198 :
199 0 : if (!socket->lockPath.empty() && (unlink(socket->lockPath.c_str()) != 0) && (errno != ENOENT))
200 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket lock file @ '%s'", socket->lockPath.c_str());
201 0 : if ((socket->lockFd >= 0) && (close(socket->lockFd) != 0))
202 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket lock file");
203 :
204 0 : socket->sockFd = -1;
205 0 : socket->sockPath.clear();
206 :
207 0 : socket->lockFd = -1;
208 0 : socket->lockPath.clear();
209 : }
210 :
211 9 : bool ServerImpl::addSocket(const std::string &socketPath,
212 : const std::function<void(const std::shared_ptr<IClient> &)> &clientConnectedCb,
213 : const std::function<void(const std::shared_ptr<IClient> &)> &clientDisconnectedCb)
214 : {
215 : // store the path
216 9 : Socket socket;
217 9 : socket.sockPath = socketPath;
218 :
219 : // create the socket
220 9 : socket.sockFd = ::socket(AF_UNIX, SOCK_SEQPACKET | SOCK_CLOEXEC | SOCK_NONBLOCK, 0);
221 9 : if (socket.sockFd == -1)
222 : {
223 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "socket error");
224 0 : return false;
225 : }
226 :
227 : // get the socket lock
228 9 : if (!getSocketLock(&socket))
229 : {
230 0 : closeListeningSocket(&socket);
231 0 : return false;
232 : }
233 :
234 : // bind to the given path
235 9 : struct sockaddr_un addr = {0};
236 9 : memset(&addr, 0x00, sizeof(addr));
237 9 : addr.sun_family = AF_UNIX;
238 9 : strncpy(addr.sun_path, socketPath.c_str(), sizeof(addr.sun_path) - 1);
239 :
240 9 : if (bind(socket.sockFd, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr)) == -1)
241 : {
242 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "bind error");
243 :
244 0 : closeListeningSocket(&socket);
245 0 : return false;
246 : }
247 :
248 : // put in listening mode
249 9 : if (listen(socket.sockFd, 1) == -1)
250 : {
251 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "listen error");
252 :
253 0 : closeListeningSocket(&socket);
254 0 : return false;
255 : }
256 :
257 : // create an id for the listening socket
258 9 : const uint64_t kSocketId = m_socketIdCounter++;
259 9 : if (kSocketId >= FIRST_CLIENT_ID)
260 : {
261 : // should never happen, we'd run out of file descriptors before
262 : // we hit the 10k limit on listening sockets
263 0 : RIALTO_IPC_LOG_ERROR("too many listening sockets");
264 :
265 0 : closeListeningSocket(&socket);
266 0 : return false;
267 : }
268 :
269 : // add the socket to epoll
270 9 : epoll_event event = {.events = EPOLLIN, .data = {.u64 = kSocketId}};
271 9 : if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socket.sockFd, &event) != 0)
272 : {
273 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add listening socket");
274 :
275 0 : closeListeningSocket(&socket);
276 0 : return false;
277 : }
278 :
279 : // store the client connected / disconnected callbacks
280 9 : socket.connectedCb = clientConnectedCb;
281 9 : socket.disconnectedCb = clientDisconnectedCb;
282 :
283 : // add to the internal map
284 9 : std::lock_guard<std::mutex> locker(m_socketsLock);
285 9 : m_sockets.emplace(kSocketId, std::move(socket));
286 :
287 9 : RIALTO_IPC_LOG_INFO("added listening socket '%s' to server", socketPath.c_str());
288 :
289 9 : return true;
290 : }
291 :
292 4 : bool ServerImpl::addSocket(int fd, const std::function<void(const std::shared_ptr<IClient> &)> &clientConnectedCb,
293 : const std::function<void(const std::shared_ptr<IClient> &)> &clientDisconnectedCb)
294 : {
295 : // store the path
296 4 : Socket socket;
297 4 : socket.isOwned = false;
298 4 : socket.sockFd = fd;
299 :
300 : // put in listening mode
301 4 : if (listen(socket.sockFd, 1) == -1)
302 : {
303 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "listen error");
304 0 : return false;
305 : }
306 :
307 : // create an id for the listening socket
308 4 : const uint64_t kSocketId = m_socketIdCounter++;
309 4 : if (kSocketId >= FIRST_CLIENT_ID)
310 : {
311 : // should never happen, we'd run out of file descriptors before
312 : // we hit the 10k limit on listening sockets
313 0 : RIALTO_IPC_LOG_ERROR("too many listening sockets");
314 0 : return false;
315 : }
316 :
317 : // add the socket to epoll
318 4 : epoll_event event = {.events = EPOLLIN, .data = {.u64 = kSocketId}};
319 4 : if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socket.sockFd, &event) != 0)
320 : {
321 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add listening socket");
322 0 : return false;
323 : }
324 :
325 : // store the client connected / disconnected callbacks
326 4 : socket.connectedCb = clientConnectedCb;
327 4 : socket.disconnectedCb = clientDisconnectedCb;
328 :
329 : // add to the internal map
330 4 : std::lock_guard<std::mutex> locker(m_socketsLock);
331 4 : m_sockets.emplace(kSocketId, std::move(socket));
332 :
333 4 : RIALTO_IPC_LOG_INFO("added listening socket with fd: %d to server", fd);
334 :
335 4 : return true;
336 : }
337 :
338 16 : std::shared_ptr<IClient> ServerImpl::addClient(int socketFd,
339 : std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb)
340 : {
341 : // sanity check the supplied socket is of the right type
342 : struct sockaddr addr;
343 16 : socklen_t len = sizeof(sockaddr);
344 16 : if ((getsockname(socketFd, &addr, &len) < 0) || (len < sizeof(sa_family_t)))
345 : {
346 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get name of supplied socket");
347 0 : return nullptr;
348 : }
349 16 : if (addr.sa_family != AF_UNIX)
350 : {
351 0 : RIALTO_IPC_LOG_ERROR("supplied client socket is not a unix domain socket");
352 0 : return nullptr;
353 : }
354 :
355 16 : int type = 0;
356 16 : len = sizeof(type);
357 16 : if ((getsockopt(socketFd, SOL_SOCKET, SO_TYPE, &type, &len) < 0) || (len != sizeof(type)))
358 : {
359 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
360 0 : return nullptr;
361 : }
362 16 : if (type != SOCK_SEQPACKET)
363 : {
364 0 : RIALTO_IPC_LOG_ERROR("supplied client socket is not of type SOCK_SEQPACKET");
365 0 : return nullptr;
366 : }
367 :
368 : // dup the socket and set the O_CLOEXEC bit
369 16 : int duppedFd = fcntl(socketFd, F_DUPFD_CLOEXEC, 3);
370 16 : if (duppedFd < 0)
371 : {
372 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
373 0 : return nullptr;
374 : }
375 :
376 : // ensure the SOCK_NONBLOCK flag is set
377 16 : int flags = fcntl(duppedFd, F_GETFL);
378 16 : if ((flags < 0) || (fcntl(duppedFd, F_SETFL, flags | O_NONBLOCK) < 0))
379 : {
380 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to set socket to non-blocking mode");
381 0 : close(duppedFd);
382 0 : return nullptr;
383 : }
384 :
385 : // finally, add the socket to the list of clients
386 48 : auto client = addClientSocket(duppedFd, "", std::move(clientDisconnectedCb));
387 16 : if (!client)
388 : {
389 0 : close(duppedFd);
390 : }
391 :
392 16 : return client;
393 : }
394 :
395 0 : int ServerImpl::fd() const
396 : {
397 0 : return m_pollFd;
398 : }
399 :
400 : // -----------------------------------------------------------------------------
401 : /*!
402 : \internal
403 :
404 : Writes to the eventfd to wake the event loop. Typically called when requesting
405 : it to shutdown or external code has requested that a client be disconnected.
406 :
407 : */
408 29 : void ServerImpl::wakeEventLoop() const
409 : {
410 29 : if (m_wakeEventFd < 0)
411 : {
412 0 : RIALTO_IPC_LOG_ERROR("invalid wake event fd");
413 : }
414 : else
415 : {
416 29 : uint64_t value = 1;
417 29 : if (TEMP_FAILURE_RETRY(::write(m_wakeEventFd, &value, sizeof(value))) != sizeof(value))
418 : {
419 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to write to the event fd");
420 : }
421 : }
422 29 : }
423 :
424 2891 : bool ServerImpl::wait(int timeoutMSecs)
425 : {
426 2891 : if (m_pollFd < 0)
427 : {
428 0 : return false;
429 : }
430 :
431 : // wait for any event (with timeout)
432 : struct pollfd fds[2];
433 2891 : fds[0].fd = m_pollFd;
434 2891 : fds[0].events = POLLIN;
435 :
436 2891 : int rc = TEMP_FAILURE_RETRY(poll(fds, 1, timeoutMSecs));
437 2891 : if (rc < 0)
438 : {
439 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "poll failed?");
440 0 : return false;
441 : }
442 :
443 2891 : return true;
444 : }
445 :
446 2920 : bool ServerImpl::process()
447 : {
448 2920 : if (m_pollFd < 0)
449 : {
450 0 : RIALTO_IPC_LOG_ERROR("missing epoll");
451 0 : return false;
452 : }
453 :
454 : // read up to 32 events
455 2920 : const int kMaxEvents = 32;
456 : struct epoll_event events[kMaxEvents];
457 :
458 2920 : int rc = TEMP_FAILURE_RETRY(epoll_wait(m_pollFd, events, kMaxEvents, 0));
459 2920 : if (rc < 0)
460 : {
461 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_wait failed");
462 0 : return false;
463 : }
464 :
465 : // process the events (maybe 0 if timed out)
466 2985 : for (int i = 0; i < rc; i++)
467 : {
468 65 : const struct epoll_event &kEvent = events[i];
469 :
470 : // check if a wake event, in which case just clear the eventfd
471 65 : if (kEvent.data.u64 == WAKE_EVENT_ID)
472 : {
473 : uint64_t ignore;
474 3 : if (TEMP_FAILURE_RETRY(::read(m_wakeEventFd, &ignore, sizeof(ignore))) != sizeof(ignore))
475 : {
476 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to read wake eventfd");
477 : }
478 : }
479 :
480 : // check for events on the listening socket
481 62 : else if (kEvent.data.u64 < FIRST_CLIENT_ID)
482 : {
483 13 : if (kEvent.events & EPOLLIN)
484 13 : processNewConnection(kEvent.data.u64);
485 13 : if (kEvent.events & EPOLLERR)
486 0 : RIALTO_IPC_LOG_ERROR("error occurred on listening socket");
487 : }
488 :
489 : // otherwise, the event must have come from a socket
490 : else
491 : {
492 49 : processClientSocket(kEvent.data.u64, kEvent.events);
493 : }
494 : }
495 :
496 : // if we have client sockets that are condemned then we need to shut down
497 : // and close them as well as remove from epoll
498 2920 : std::unique_lock<std::mutex> locker(m_clientsLock);
499 2920 : if (!m_condemnedClients.empty())
500 : {
501 : // take a copy of the set so we can process without the lock held
502 29 : std::set<uint64_t> theCondemned;
503 29 : m_condemnedClients.swap(theCondemned);
504 :
505 58 : for (uint64_t clientId : theCondemned)
506 : {
507 29 : auto it = m_clients.find(clientId);
508 29 : if (it == m_clients.end())
509 : {
510 0 : RIALTO_IPC_LOG_ERROR("failed to find condemned client");
511 0 : continue;
512 : }
513 :
514 29 : ClientDetails details = it->second;
515 :
516 : // remove from the list of clients
517 29 : m_clients.erase(it);
518 :
519 : // drop the lock while closing the connection and removing from epoll
520 29 : locker.unlock();
521 :
522 : // remove the socket from epoll and close it
523 29 : if (details.sock >= 0)
524 : {
525 29 : if (epoll_ctl(m_pollFd, EPOLL_CTL_DEL, details.sock, nullptr) != 0)
526 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket from epoll");
527 :
528 29 : if (shutdown(details.sock, SHUT_RDWR) != 0)
529 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to shutdown socket");
530 29 : if (close(details.sock) != 0)
531 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket");
532 : }
533 :
534 : // let the installed handler know a client has disconnected
535 29 : if (details.disconnectedCb)
536 29 : details.disconnectedCb(details.client);
537 :
538 : // ensure client object is destructed without the client lock held
539 29 : details.client.reset();
540 :
541 : // re-take the lock for the next client
542 29 : locker.lock();
543 : }
544 : }
545 :
546 2920 : return true;
547 : }
548 :
549 : // -----------------------------------------------------------------------------
550 : /*!
551 : \internal
552 :
553 : Adds the \a socketFd to the internal list of client sockets. This is called
554 : when a new connection is accepted on a listening socket, or when a client
555 : fd is added via ServerImpl::addClient(...).
556 :
557 : Returns a nullptr if failed to add the socket.
558 :
559 : */
560 29 : std::shared_ptr<ClientImpl> ServerImpl::addClientSocket(int socketFd, const std::string &listeningSocketPath,
561 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb)
562 : {
563 : // get the client credentials
564 29 : struct ucred clientCreds = {0};
565 29 : socklen_t clientCredsLen = sizeof(clientCreds);
566 :
567 58 : if ((getsockopt(socketFd, SOL_SOCKET, SO_PEERCRED, &clientCreds, &clientCredsLen) < 0) ||
568 29 : (clientCredsLen != sizeof(clientCreds)))
569 : {
570 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get client's details");
571 0 : return nullptr;
572 : }
573 :
574 : // create a new unique client id for the connection
575 29 : const uint64_t kClientId = m_clientIdCounter++;
576 :
577 : // add the new socket to the poll loop
578 29 : epoll_event event = {.events = EPOLLIN, .data = {.u64 = kClientId}};
579 29 : if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socketFd, &event) != 0)
580 : {
581 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add client socket");
582 0 : return nullptr;
583 : }
584 :
585 : // create initial client object for the socket
586 29 : auto client = std::make_shared<ClientImpl>(shared_from_this(), kClientId, clientCreds);
587 :
588 : // and the details for the internal list
589 29 : ClientDetails clientDetails;
590 29 : clientDetails.sock = socketFd;
591 29 : clientDetails.disconnectedCb = std::move(disconnectedCb);
592 29 : clientDetails.client = client;
593 :
594 : // add to the set of clients
595 : {
596 29 : std::lock_guard<std::mutex> locker(m_clientsLock);
597 29 : m_clients.emplace(kClientId, clientDetails);
598 : }
599 :
600 29 : RIALTO_IPC_LOG_INFO("new client connected - giving id %" PRIu64, kClientId);
601 :
602 29 : return client;
603 : }
604 :
605 : // -----------------------------------------------------------------------------
606 : /*!
607 : \internal
608 :
609 : Called when an event occurs on the listening socket. The code first accepts
610 : the connection and retrieves the client details, it then calls the installed
611 : handler to determine if we should drop this connection or not.
612 :
613 :
614 : */
615 13 : void ServerImpl::processNewConnection(uint64_t socketId)
616 : {
617 13 : RIALTO_IPC_LOG_DEBUG("processing new connection");
618 :
619 : int clientSock;
620 13 : std::string sockPath;
621 13 : std::function<void(const std::shared_ptr<IClient> &)> connectedCb;
622 13 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb;
623 :
624 : {
625 13 : std::unique_lock<std::mutex> socketLocker(m_socketsLock);
626 :
627 : // find matching socket object
628 13 : auto it = m_sockets.find(socketId);
629 13 : if (it == m_sockets.end())
630 : {
631 0 : RIALTO_IPC_LOG_ERROR("failed to find listening socket with id %" PRIu64, socketId);
632 0 : return;
633 : }
634 :
635 13 : const Socket &kSocket = it->second;
636 :
637 : // accept the connection from the client
638 13 : struct sockaddr clientAddr = {0};
639 13 : socklen_t clientAddrLen = sizeof(clientAddr);
640 :
641 13 : clientSock = accept4(kSocket.sockFd, &clientAddr, &clientAddrLen, SOCK_NONBLOCK | SOCK_CLOEXEC);
642 13 : if (clientSock < 0)
643 : {
644 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to accept client connection");
645 0 : return;
646 : }
647 :
648 13 : sockPath = kSocket.sockPath;
649 13 : connectedCb = kSocket.connectedCb;
650 13 : disconnectedCb = kSocket.disconnectedCb;
651 : }
652 :
653 : // attempt to add the socket to the client list
654 13 : auto client = addClientSocket(clientSock, sockPath, std::move(disconnectedCb));
655 13 : if (!client)
656 : {
657 0 : close(clientSock);
658 0 : return;
659 : }
660 :
661 : // notify the handler that a new connection has been made
662 13 : if (connectedCb)
663 : {
664 : // tell the handler we have a new client
665 13 : connectedCb(client);
666 : }
667 : }
668 :
669 : // -----------------------------------------------------------------------------
670 : /*!
671 : \internal
672 : \static
673 :
674 : Reads all the file descriptors from a message. It returns the file descriptors
675 : as a vector of FileDescriptor objects, these objects safely store the fd and
676 : close them when they're destructed.
677 :
678 : */
679 0 : static std::vector<FileDescriptor> readMessageFds(const struct msghdr *msg, size_t limit)
680 : {
681 0 : std::vector<FileDescriptor> fds;
682 :
683 0 : for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg); cmsg != nullptr;
684 0 : cmsg = CMSG_NXTHDR(const_cast<struct msghdr *>(msg), cmsg))
685 : {
686 0 : if ((cmsg->cmsg_level == SOL_SOCKET) && (cmsg->cmsg_type == SCM_RIGHTS))
687 : {
688 0 : const unsigned kFdsLength = cmsg->cmsg_len - CMSG_LEN(0);
689 0 : if ((kFdsLength < sizeof(int)) || ((kFdsLength % sizeof(int)) != 0))
690 : {
691 0 : RIALTO_IPC_LOG_ERROR("invalid fd array size");
692 : }
693 : else
694 : {
695 0 : const size_t n = kFdsLength / sizeof(int);
696 0 : RIALTO_IPC_LOG_DEBUG("received %zu fds", n);
697 :
698 0 : fds.reserve(std::min(limit, n));
699 :
700 0 : const int *kFds = reinterpret_cast<int *>(CMSG_DATA(cmsg));
701 0 : for (size_t i = 0; i < n; i++)
702 : {
703 0 : RIALTO_IPC_LOG_DEBUG("received fd %d", kFds[i]);
704 :
705 0 : if (fds.size() >= limit)
706 : {
707 0 : RIALTO_IPC_LOG_ERROR("received to many file descriptors, "
708 : "exceeding max per message, closing left overs");
709 : }
710 : else
711 : {
712 0 : FileDescriptor fd(kFds[i]);
713 0 : if (!fd.isValid())
714 : {
715 0 : RIALTO_IPC_LOG_ERROR("received invalid fd (couldn't dup)");
716 : }
717 : else
718 : {
719 0 : fds.emplace_back(std::move(fd));
720 : }
721 : }
722 :
723 0 : if (close(kFds[i]) != 0)
724 : {
725 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close received fd");
726 : }
727 : }
728 : }
729 : }
730 : }
731 :
732 0 : return fds;
733 : }
734 :
735 : // -----------------------------------------------------------------------------
736 : /*!
737 : \internal
738 :
739 : Processes an event from a client socket.
740 :
741 : */
742 49 : void ServerImpl::processClientSocket(uint64_t clientId, unsigned events)
743 : {
744 : int sockFd;
745 49 : std::shared_ptr<ClientImpl> clientObj;
746 :
747 : // take the lock while accessing the client list
748 : {
749 49 : std::unique_lock<std::mutex> locker(m_clientsLock);
750 :
751 49 : auto it = m_clients.find(clientId);
752 49 : if (it == m_clients.end())
753 : {
754 : // should never happen
755 0 : RIALTO_IPC_LOG_ERROR("received an event from a socket with no matching client");
756 0 : return;
757 : }
758 :
759 : // check if the client is marked for closure, if so then just ignore the data
760 49 : if (m_condemnedClients.count(clientId) != 0)
761 : {
762 0 : return;
763 : }
764 :
765 : // get the socket that corresponds to the client connection
766 49 : sockFd = it->second.sock;
767 :
768 : // get the client object
769 49 : clientObj = it->second.client;
770 : }
771 : // lock is released here automatically
772 :
773 : // if there was an error disconnect the socket
774 49 : if (events & EPOLLERR)
775 : {
776 0 : RIALTO_IPC_LOG_ERROR("error detected on client socket - disconnecting client");
777 0 : disconnectClient(clientId);
778 0 : return;
779 : }
780 :
781 49 : if (events & EPOLLIN)
782 : {
783 : // read all messages from the client socket, we break out if the socket is closed
784 : // or EWOULDBLOCK is returned on a read (ie. no more messages to read)
785 : while (true)
786 : {
787 70 : struct msghdr msg = {nullptr};
788 70 : struct iovec io = {.iov_base = m_recvDataBuf, .iov_len = sizeof(m_recvDataBuf)};
789 :
790 70 : bzero(&msg, sizeof(msg));
791 70 : msg.msg_iov = &io;
792 70 : msg.msg_iovlen = 1;
793 70 : msg.msg_control = m_recvCtrlBuf;
794 70 : msg.msg_controllen = sizeof(m_recvCtrlBuf);
795 :
796 : // read one message
797 70 : ssize_t rd = TEMP_FAILURE_RETRY(recvmsg(sockFd, &msg, MSG_CMSG_CLOEXEC));
798 70 : if (rd < 0)
799 : {
800 21 : if (errno != EWOULDBLOCK)
801 : {
802 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "error reading client socket");
803 0 : disconnectClient(clientId);
804 : }
805 :
806 49 : break;
807 : }
808 49 : else if (rd == 0)
809 : {
810 : // client closed connection, and we've read all data, add to the condemned set
811 : // so is cleaned up once all the events are processed
812 28 : disconnectClient(clientId);
813 :
814 28 : break;
815 : }
816 21 : else if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))
817 : {
818 0 : RIALTO_IPC_LOG_WARN("received message from client %" PRIu64 " truncated, discarding", clientId);
819 :
820 : // make sure to close all the fds, otherwise we'll leak them
821 0 : readMessageFds(&msg, 16);
822 : }
823 : else
824 : {
825 : // if there is control data then assume fd(s) have been passed
826 21 : if (msg.msg_controllen > 0)
827 : {
828 0 : processClientMessage(clientObj, m_recvDataBuf, rd, readMessageFds(&msg, 16));
829 : }
830 : else
831 : {
832 21 : processClientMessage(clientObj, m_recvDataBuf, rd);
833 : }
834 : }
835 : }
836 : }
837 49 : }
838 :
839 : // -----------------------------------------------------------------------------
840 : /*!
841 : \internal
842 :
843 : Places the received file descriptors into the request message.
844 :
845 : It works by iterating over the fields in the message, finding ones that are
846 : marked as 'field_is_fd' and then replacing the received integer value with
847 : an actual file descriptor.
848 :
849 : */
850 21 : static bool addRequestFileDescriptors(google::protobuf::Message *request, const std::vector<FileDescriptor> &requestFds)
851 : {
852 21 : auto fdIterator = requestFds.begin();
853 :
854 21 : const google::protobuf::Descriptor *kDescriptor = request->GetDescriptor();
855 21 : const google::protobuf::Reflection *kReflection = nullptr;
856 :
857 21 : const int n = kDescriptor->field_count();
858 79 : for (int i = 0; i < n; i++)
859 : {
860 58 : auto fieldDescriptor = kDescriptor->field(i);
861 58 : if (fieldDescriptor->options().HasExtension(field_is_fd) && fieldDescriptor->options().GetExtension(field_is_fd))
862 : {
863 0 : if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
864 : {
865 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
866 0 : return false;
867 : }
868 :
869 0 : if (!kReflection)
870 : {
871 0 : kReflection = request->GetReflection();
872 : }
873 :
874 0 : if (kReflection->HasField(*request, fieldDescriptor))
875 : {
876 0 : if (fdIterator == requestFds.end())
877 : {
878 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but one was supplied");
879 0 : return false;
880 : }
881 :
882 0 : kReflection->SetInt32(request, fieldDescriptor, fdIterator->fd());
883 0 : ++fdIterator;
884 : }
885 : }
886 : }
887 :
888 21 : if (fdIterator != requestFds.end())
889 : {
890 0 : RIALTO_IPC_LOG_ERROR("received too many file descriptors in the message");
891 0 : return false;
892 : }
893 :
894 21 : return true;
895 : }
896 :
897 : // -----------------------------------------------------------------------------
898 : /*!
899 : \internal
900 :
901 : Processes a message received on a client socket.
902 :
903 : */
904 21 : void ServerImpl::processClientMessage(const std::shared_ptr<ClientImpl> &client, const uint8_t *data, size_t dataLen,
905 : const std::vector<FileDescriptor> &fds)
906 : {
907 21 : RIALTO_IPC_LOG_DEBUG("processing client message of size %zu bytes (%zu fds) from client %" PRId64, dataLen,
908 : fds.size(), client->id());
909 :
910 : // parse the message
911 21 : transport::MessageToServer message;
912 21 : if (!message.ParseFromArray(data, static_cast<int>(dataLen)))
913 : {
914 0 : RIALTO_IPC_LOG_ERROR("invalid request");
915 0 : return;
916 : }
917 :
918 21 : if (message.has_call())
919 : {
920 21 : processMethodCall(client, message.call(), fds);
921 : }
922 : else
923 : {
924 0 : RIALTO_IPC_LOG_WARN("received unknown message type from client");
925 : }
926 21 : }
927 :
928 : // -----------------------------------------------------------------------------
929 : /*!
930 : \internal
931 :
932 : Processes a method call requst from a client.
933 :
934 : */
935 21 : void ServerImpl::processMethodCall(const std::shared_ptr<ClientImpl> &client, const transport::MethodCall &call,
936 : const std::vector<FileDescriptor> &fds)
937 : {
938 : // try and find the service with the given name
939 21 : const std::string &kServiceName = call.service_name();
940 21 : auto it = client->m_services.find(kServiceName);
941 21 : if (it == client->m_services.end())
942 : {
943 0 : RIALTO_IPC_LOG_ERROR("unknown service request '%s'", kServiceName.c_str());
944 :
945 0 : sendErrorReply(client, call.serial_id(), "Unknown service '%s'", kServiceName.c_str());
946 0 : return;
947 : }
948 :
949 21 : std::shared_ptr<google::protobuf::Service> service = it->second;
950 :
951 : // try and find the method
952 21 : const std::string &kMethodName = call.method_name();
953 21 : const google::protobuf::MethodDescriptor *kMethod = service->GetDescriptor()->FindMethodByName(kMethodName);
954 21 : if (!kMethod)
955 : {
956 0 : RIALTO_IPC_LOG_ERROR("no method with name '%s'", kMethodName.c_str());
957 :
958 0 : sendErrorReply(client, call.serial_id(), "Unknown method '%s'", kMethodName.c_str());
959 0 : return;
960 : }
961 :
962 : // check if the method is expecting a reply
963 21 : const bool kNoReply = kMethod->options().HasExtension(no_reply) && kMethod->options().GetExtension(no_reply);
964 :
965 : // parse the request data
966 21 : google::protobuf::Message *requestMessage = service->GetRequestPrototype(kMethod).New();
967 21 : if (!requestMessage->ParseFromString(call.request_message()))
968 : {
969 0 : RIALTO_IPC_LOG_ERROR("failed to parse method from array");
970 : }
971 21 : else if (!addRequestFileDescriptors(requestMessage, fds))
972 : {
973 0 : RIALTO_IPC_LOG_ERROR("mismatch of file descriptors to the request");
974 : }
975 : else
976 : {
977 21 : RIALTO_IPC_LOG_DEBUG("call{ serial %" PRIu64 " } - %s.%s { %s }", call.serial_id(), kServiceName.c_str(),
978 : kMethodName.c_str(), requestMessage->ShortDebugString().c_str());
979 :
980 21 : auto *controller = new ServerControllerImpl(client, call.serial_id());
981 :
982 21 : if (kNoReply)
983 : {
984 : // we should not send a reply for this call, so call the code to handle the
985 : // request, but no need to pass a controller, response or closure object
986 1 : static google::protobuf::internal::FunctionClosure0 nullClosure(&google::protobuf::DoNothing, false);
987 1 : service->CallMethod(kMethod, controller, requestMessage, nullptr, &nullClosure);
988 :
989 1 : delete controller;
990 : }
991 : else
992 : {
993 : // create a response
994 20 : google::protobuf::Message *responseMessage = service->GetResponsePrototype(kMethod).New();
995 :
996 : // this is finally where we call the service implementation to process the request
997 20 : service->CallMethod(kMethod, controller, requestMessage, responseMessage,
998 : google::protobuf::NewCallback(this, &ServerImpl::handleResponse, controller,
999 : responseMessage));
1000 : }
1001 : }
1002 :
1003 21 : delete requestMessage;
1004 : }
1005 :
1006 : // -----------------------------------------------------------------------------
1007 : /*!
1008 : \internal
1009 : \threadsafe
1010 :
1011 : Sends a message to the given client if still connected.
1012 :
1013 : */
1014 20 : void ServerImpl::sendReply(uint64_t clientId, const std::shared_ptr<msghdr> &msg)
1015 : {
1016 : // now take the lock (so the socket is not closed beneath us) and send the reply
1017 20 : std::lock_guard<std::mutex> locker(m_clientsLock);
1018 :
1019 20 : auto it = m_clients.find(clientId);
1020 20 : if (it == m_clients.end())
1021 : {
1022 1 : RIALTO_IPC_LOG_WARN("socket removed before error reply could be sent");
1023 : }
1024 19 : else if (it->second.sock < 0)
1025 : {
1026 0 : RIALTO_IPC_LOG_WARN("socket closed before error reply could be sent");
1027 : }
1028 19 : else if (!msg)
1029 : {
1030 0 : RIALTO_IPC_LOG_WARN("invalid msg to send on socket, ignoring");
1031 : }
1032 19 : else if (TEMP_FAILURE_RETRY(sendmsg(it->second.sock, msg.get(), MSG_NOSIGNAL)) !=
1033 19 : static_cast<ssize_t>(msg->msg_iov->iov_len))
1034 : {
1035 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to send the complete error reply message");
1036 : }
1037 20 : }
1038 :
1039 : // -----------------------------------------------------------------------------
1040 : /*!
1041 : \internal
1042 :
1043 : Sends an error back to the client with a reason string in \a format.
1044 :
1045 : */
1046 0 : void ServerImpl::sendErrorReply(const std::shared_ptr<ClientImpl> &client, uint64_t serialId, const char *format, ...)
1047 : {
1048 : // construct the error string
1049 : va_list ap;
1050 0 : va_start(ap, format);
1051 : char reason[512];
1052 0 : vsnprintf(reason, sizeof(reason), format, ap);
1053 0 : va_end(ap);
1054 :
1055 : // construct the reply message
1056 0 : auto msg = populateErrorReply(client, serialId, reason);
1057 :
1058 : // and send it
1059 0 : sendReply(client->id(), msg);
1060 : }
1061 :
1062 : // -----------------------------------------------------------------------------
1063 : /*!
1064 : \internal
1065 : \threadsafe
1066 :
1067 : Called once the service handler code has processed the request and
1068 : completed it. This may be called from a different thread if the service
1069 : implementation decided to off-load the request to a processing thread.
1070 :
1071 : */
1072 20 : void ServerImpl::handleResponse(ServerControllerImpl *controller, google::protobuf::Message *response)
1073 : {
1074 20 : if (!controller || !controller->m_kClient)
1075 : {
1076 0 : RIALTO_IPC_LOG_ERROR("missing controller or attached client");
1077 0 : return;
1078 : }
1079 :
1080 20 : const std::shared_ptr<const ClientImpl> kClient = controller->m_kClient;
1081 20 : const uint64_t kClientId = kClient->id();
1082 :
1083 20 : std::shared_ptr<msghdr> message;
1084 20 : if (!controller->m_failed)
1085 : {
1086 13 : message = populateReply(kClient, controller->m_kSerialId, response);
1087 : }
1088 : else
1089 : {
1090 7 : message = populateErrorReply(kClient, controller->m_kSerialId, controller->m_failureReason);
1091 : }
1092 :
1093 : // no longer need the controller or the response objects
1094 20 : delete response;
1095 20 : delete controller;
1096 :
1097 : // send the reply message to the given client
1098 20 : sendReply(kClientId, message);
1099 : }
1100 :
1101 : // -----------------------------------------------------------------------------
1102 : /*!
1103 : \internal
1104 :
1105 : Gets all the file descriptors that are stored in the \a response message
1106 : and returns them as a vector of ints.
1107 :
1108 : It works by iterating over the fields in the message, finding ones that are
1109 : marked as 'field_is_fd' and inserting the fd value into control part of the
1110 : message. The actual integer sent in the data part of the message is set to
1111 : -1.
1112 :
1113 : */
1114 17 : static std::vector<int> getResponseFileDescriptors(google::protobuf::Message *response)
1115 : {
1116 17 : std::vector<int> fds;
1117 :
1118 : // process any file descriptors from the response message
1119 17 : const google::protobuf::Descriptor *kDescriptor = response->GetDescriptor();
1120 17 : const google::protobuf::Reflection *kReflection = nullptr;
1121 :
1122 17 : const int n = kDescriptor->field_count();
1123 34 : for (int i = 0; i < n; i++)
1124 : {
1125 17 : auto fieldDescriptor = kDescriptor->field(i);
1126 17 : if (fieldDescriptor->options().HasExtension(field_is_fd) && fieldDescriptor->options().GetExtension(field_is_fd))
1127 : {
1128 0 : if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
1129 : {
1130 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
1131 0 : return {};
1132 : }
1133 :
1134 0 : if (!kReflection)
1135 : {
1136 0 : kReflection = response->GetReflection();
1137 : }
1138 :
1139 0 : if (kReflection->HasField(*response, fieldDescriptor))
1140 : {
1141 0 : fds.push_back(kReflection->GetInt32(*response, fieldDescriptor));
1142 0 : kReflection->SetInt32(response, fieldDescriptor, -1);
1143 : }
1144 : }
1145 : }
1146 :
1147 17 : return fds;
1148 : }
1149 :
1150 : // -----------------------------------------------------------------------------
1151 : /*!
1152 : \internal
1153 : \static
1154 :
1155 : Populates tha socket message buffer with the reply data for the RPC request.
1156 :
1157 : */
1158 13 : std::shared_ptr<msghdr> ServerImpl::populateReply(const std::shared_ptr<const ClientImpl> &client, uint64_t serialId,
1159 : google::protobuf::Message *response)
1160 : {
1161 : // create the base reply
1162 13 : transport::MessageFromServer message;
1163 13 : transport::MethodCallReply *reply = message.mutable_reply();
1164 13 : if (!reply)
1165 : {
1166 0 : RIALTO_IPC_LOG_ERROR("failed to create mutable reply object");
1167 0 : return nullptr;
1168 : }
1169 :
1170 : // convert the message response to a data string
1171 13 : std::string respString = response->SerializeAsString();
1172 :
1173 : // wrap in a transport response and send that
1174 13 : reply->set_reply_id(serialId);
1175 13 : reply->set_reply_message(std::move(respString));
1176 :
1177 : // next need to check if the response message has any file descriptors in
1178 : // it that need to be attached
1179 13 : const std::vector<int> kFds = getResponseFileDescriptors(response);
1180 13 : const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
1181 :
1182 : // calculate the size of the reply
1183 13 : const size_t kRequiredDataLen = message.ByteSizeLong();
1184 13 : if (kRequiredDataLen > m_kMaxMessageLen)
1185 : {
1186 0 : RIALTO_IPC_LOG_ERROR("reply exceeds maximum message limit (%zu, max %zu)", kRequiredDataLen, m_kMaxMessageLen);
1187 :
1188 : // error message is too big, replace with a generic error
1189 0 : return populateErrorReply(client, serialId, "Internal error - reply message to large");
1190 : }
1191 :
1192 : // build the socket message to send
1193 : auto msgBuf =
1194 13 : m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + kRequiredDataLen);
1195 :
1196 13 : auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
1197 13 : bzero(header, sizeof(msghdr));
1198 :
1199 13 : auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
1200 13 : header->msg_control = ctrl;
1201 13 : header->msg_controllen = kRequiredCtrlLen;
1202 :
1203 13 : auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
1204 13 : header->msg_iov = iov;
1205 13 : header->msg_iovlen = 1;
1206 :
1207 13 : auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
1208 13 : iov->iov_base = data;
1209 13 : iov->iov_len = kRequiredDataLen;
1210 :
1211 : // copy in the data
1212 13 : message.SerializeWithCachedSizesToArray(data);
1213 :
1214 : // add the fds
1215 13 : if (!kFds.empty())
1216 : {
1217 0 : struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
1218 0 : if (!cmsg)
1219 : {
1220 0 : RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
1221 0 : return nullptr;
1222 : }
1223 :
1224 0 : cmsg->cmsg_level = SOL_SOCKET;
1225 0 : cmsg->cmsg_type = SCM_RIGHTS;
1226 0 : cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
1227 0 : memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
1228 0 : header->msg_controllen = cmsg->cmsg_len;
1229 : }
1230 :
1231 13 : RIALTO_IPC_LOG_DEBUG("reply{ serial %" PRIu64 " } - { %s }", serialId, response->ShortDebugString().c_str());
1232 :
1233 : // std::reinterpret_pointer_cast is only implemented in C++17 and newer, so for
1234 : // now do it manually
1235 13 : return std::shared_ptr<msghdr>(msgBuf, reinterpret_cast<msghdr *>(msgBuf.get()));
1236 : }
1237 :
1238 : // -----------------------------------------------------------------------------
1239 : /*!
1240 : \internal
1241 :
1242 : Populates tha socket message buffer with a reply error message.
1243 :
1244 : */
1245 7 : std::shared_ptr<msghdr> ServerImpl::populateErrorReply(const std::shared_ptr<const ClientImpl> &client,
1246 : uint64_t serialId, const std::string &reason)
1247 : {
1248 : // create the base reply
1249 7 : transport::MessageFromServer message;
1250 7 : transport::MethodCallError *error = message.mutable_error();
1251 7 : error->set_reply_id(serialId);
1252 : error->set_error_reason(reason);
1253 :
1254 : // check the message will fit
1255 7 : size_t replySize = message.ByteSizeLong();
1256 7 : if (replySize > m_kMaxMessageLen)
1257 : {
1258 0 : RIALTO_IPC_LOG_ERROR("error reply exceeds max message size");
1259 :
1260 : // error message is to big, replace with a generic error
1261 : error->set_error_reason("Error message truncated");
1262 0 : replySize = message.ByteSizeLong();
1263 : }
1264 :
1265 7 : RIALTO_IPC_LOG_DEBUG("error{ serial %" PRIu64 " } - \"%s\"", serialId, reason.c_str());
1266 :
1267 : // construct the message to send on the socket
1268 7 : auto msgBuf = m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + replySize);
1269 :
1270 7 : auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
1271 7 : bzero(header, sizeof(msghdr));
1272 :
1273 7 : auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr));
1274 7 : header->msg_iov = iov;
1275 7 : header->msg_iovlen = 1;
1276 :
1277 7 : auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + sizeof(iovec));
1278 7 : iov->iov_base = data;
1279 7 : iov->iov_len = replySize;
1280 :
1281 : // serialise the reply to the message buffer
1282 7 : message.SerializeWithCachedSizesToArray(data);
1283 :
1284 : // std::reinterpret_pointer_cast is only implemented in C++17 and newer, so for
1285 : // now do it manually
1286 14 : return std::shared_ptr<msghdr>(msgBuf, reinterpret_cast<msghdr *>(msgBuf.get()));
1287 7 : }
1288 :
1289 : // -----------------------------------------------------------------------------
1290 : /*!
1291 : \threadsafe
1292 :
1293 : Returns \c true if the client with \a clientId is currently connected to
1294 : the server.
1295 :
1296 : */
1297 61 : bool ServerImpl::isClientConnected(uint64_t clientId) const
1298 : {
1299 61 : std::unique_lock<std::mutex> locker(m_clientsLock);
1300 122 : return (m_clients.count(clientId) > 0);
1301 61 : }
1302 :
1303 : // -----------------------------------------------------------------------------
1304 : /*!
1305 : \threadsafe
1306 :
1307 : May be called internally when there is an error on the socket, or externally
1308 : (possibly from a different thread) if the handler or service code decides to
1309 : close the client connection.
1310 :
1311 : */
1312 29 : void ServerImpl::disconnectClient(uint64_t clientId)
1313 : {
1314 : {
1315 29 : std::lock_guard<std::mutex> locker(m_clientsLock);
1316 29 : m_condemnedClients.insert(clientId);
1317 : }
1318 :
1319 29 : wakeEventLoop();
1320 : }
1321 :
1322 : // -----------------------------------------------------------------------------
1323 : /*!
1324 : \threadsafe
1325 :
1326 : Called via the IAVBusClient interface when an async event should be sent.
1327 :
1328 : This may be called from any thread, or from within the rpc message handler.
1329 :
1330 : The \a clientId is the client to send the event to.
1331 :
1332 : */
1333 4 : bool ServerImpl::sendEvent(uint64_t clientId, const std::shared_ptr<google::protobuf::Message> &eventMessage)
1334 : {
1335 : // gets the file descriptors from the event message
1336 4 : const std::vector<int> kFds = getResponseFileDescriptors(eventMessage.get());
1337 4 : const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
1338 :
1339 : // create the base reply
1340 4 : transport::MessageFromServer message;
1341 4 : transport::EventFromServer *event = message.mutable_event();
1342 4 : if (!event)
1343 : {
1344 0 : RIALTO_IPC_LOG_ERROR("failed to create mutable event object");
1345 0 : return false;
1346 : }
1347 :
1348 8 : event->set_event_name(eventMessage->GetTypeName());
1349 :
1350 : // convert the event to a data string
1351 4 : std::string respString = eventMessage->SerializeAsString();
1352 :
1353 : // wrap in a transport response and send that
1354 4 : event->set_message(std::move(respString));
1355 :
1356 : // check the reply will fit
1357 4 : size_t requiredDataLen = message.ByteSizeLong();
1358 4 : if (requiredDataLen > m_kMaxMessageLen)
1359 : {
1360 0 : RIALTO_IPC_LOG_ERROR("event message to big to fit in buffer (size %zu, max size %zu)", requiredDataLen,
1361 : m_kMaxMessageLen);
1362 0 : return false;
1363 : }
1364 :
1365 : // build the socket message to send
1366 : auto msgBuf =
1367 4 : m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + requiredDataLen);
1368 :
1369 4 : auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
1370 4 : bzero(header, sizeof(msghdr));
1371 :
1372 4 : auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
1373 4 : header->msg_control = ctrl;
1374 4 : header->msg_controllen = kRequiredCtrlLen;
1375 :
1376 4 : auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
1377 4 : header->msg_iov = iov;
1378 4 : header->msg_iovlen = 1;
1379 :
1380 4 : auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
1381 4 : iov->iov_base = data;
1382 4 : iov->iov_len = requiredDataLen;
1383 :
1384 : // copy in the data
1385 4 : message.SerializeWithCachedSizesToArray(data);
1386 :
1387 : // add the fds
1388 4 : if (!kFds.empty())
1389 : {
1390 0 : struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
1391 0 : if (!cmsg)
1392 : {
1393 0 : RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
1394 0 : return false;
1395 : }
1396 :
1397 0 : cmsg->cmsg_level = SOL_SOCKET;
1398 0 : cmsg->cmsg_type = SCM_RIGHTS;
1399 0 : cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
1400 0 : memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
1401 0 : header->msg_controllen = cmsg->cmsg_len;
1402 : }
1403 :
1404 : // finally, take the lock (so the socket is not closed beneath us) and send the reply
1405 : {
1406 4 : std::unique_lock<std::mutex> locker(m_clientsLock);
1407 :
1408 4 : auto it = m_clients.find(clientId);
1409 4 : if (it == m_clients.end() || it->second.sock < 0)
1410 : {
1411 0 : RIALTO_IPC_LOG_WARN("socket closed before event could be sent");
1412 0 : return false;
1413 : }
1414 4 : else if (TEMP_FAILURE_RETRY(sendmsg(it->second.sock, header, MSG_NOSIGNAL)) !=
1415 4 : static_cast<ssize_t>(requiredDataLen))
1416 : {
1417 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to send the complete event message");
1418 0 : return false;
1419 : }
1420 4 : }
1421 :
1422 4 : RIALTO_IPC_LOG_DEBUG("event{ %s } - { %s }", eventMessage->GetTypeName().c_str(),
1423 : eventMessage->ShortDebugString().c_str());
1424 :
1425 4 : return true;
1426 : }
1427 : } // namespace firebolt::rialto::ipc
|