Line data Source code
1 : /*
2 : * If not stated otherwise in this file or this component's LICENSE file the
3 : * following copyright and licenses apply:
4 : *
5 : * Copyright 2022 Sky UK
6 : *
7 : * Licensed under the Apache License, Version 2.0 (the "License");
8 : * you may not use this file except in compliance with the License.
9 : * You may obtain a copy of the License at
10 : *
11 : * http://www.apache.org/licenses/LICENSE-2.0
12 : *
13 : * Unless required by applicable law or agreed to in writing, software
14 : * distributed under the License is distributed on an "AS IS" BASIS,
15 : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 : * See the License for the specific language governing permissions and
17 : * limitations under the License.
18 : */
19 :
20 : #include "IpcServerImpl.h"
21 : #include "IIpcServerFactory.h"
22 : #include "IpcClientImpl.h"
23 : #include "IpcLogging.h"
24 : #include "IpcServerControllerImpl.h"
25 :
26 : #include "rialtoipc.pb.h"
27 :
28 : #include <google/protobuf/service.h>
29 :
30 : #include <algorithm>
31 : #include <cinttypes>
32 : #include <cstdarg>
33 : #include <string>
34 : #include <utility>
35 : #include <vector>
36 :
37 : #include <fcntl.h>
38 : #include <poll.h>
39 : #include <sys/epoll.h>
40 : #include <sys/eventfd.h>
41 : #include <sys/file.h>
42 : #include <sys/socket.h>
43 : #include <sys/stat.h>
44 : #include <sys/un.h>
45 : #include <unistd.h>
46 :
// Identifiers stored in epoll_event::data.u64 so process() can tell event
// sources apart:
//  - WAKE_EVENT_ID is the eventfd used to wake the poll loop,
//  - ids in [FIRST_LISTENING_SOCKET_ID, FIRST_CLIENT_ID) are listening sockets,
//  - ids >= FIRST_CLIENT_ID are connected client sockets.
#define WAKE_EVENT_ID uint64_t(0)
#define FIRST_LISTENING_SOCKET_ID uint64_t(1)
#define FIRST_CLIENT_ID uint64_t(10000)
50 :
51 : namespace firebolt::rialto::ipc
52 : {
53 : const size_t ServerImpl::m_kMaxMessageLen = (128 * 1024);
54 :
55 32 : std::shared_ptr<IServerFactory> IServerFactory::createFactory()
56 : {
57 32 : std::shared_ptr<IServerFactory> factory;
58 : try
59 : {
60 32 : factory = std::make_shared<ServerFactory>();
61 : }
62 0 : catch (const std::exception &e)
63 : {
64 0 : RIALTO_IPC_LOG_ERROR("Failed to create the server factory, reason: %s", e.what());
65 : }
66 :
67 32 : return factory;
68 : }
69 :
70 32 : std::shared_ptr<IServer> ServerFactory::create()
71 : {
72 32 : return std::make_shared<ServerImpl>();
73 : }
74 :
75 32 : ServerImpl::ServerImpl()
76 32 : : m_pollFd(-1), m_wakeEventFd(-1), m_socketIdCounter(FIRST_LISTENING_SOCKET_ID), m_clientIdCounter(FIRST_CLIENT_ID),
77 64 : m_recvDataBuf{0}, m_recvCtrlBuf{0}
78 : {
79 : // create the eventfd use to wake the poll loop
80 32 : m_wakeEventFd = eventfd(0, EFD_CLOEXEC);
81 32 : if (m_wakeEventFd < 0)
82 : {
83 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "eventfd failed");
84 0 : return;
85 : }
86 :
87 : // create epoll loop
88 32 : m_pollFd = epoll_create1(EPOLL_CLOEXEC);
89 32 : if (m_pollFd < 0)
90 : {
91 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_create1 failed");
92 0 : return;
93 : }
94 :
95 : // add the wake event to epoll
96 32 : epoll_event event = {.events = EPOLLIN, .data = {.u64 = WAKE_EVENT_ID}};
97 32 : if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, m_wakeEventFd, &event) != 0)
98 : {
99 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
100 : }
101 : }
102 :
103 32 : ServerImpl::~ServerImpl()
104 : {
105 32 : if ((m_pollFd >= 0) && (close(m_pollFd) != 0))
106 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close epoll");
107 :
108 32 : if ((m_wakeEventFd >= 0) && (close(m_wakeEventFd) != 0))
109 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close eventfd");
110 :
111 45 : for (const auto &entry : m_sockets)
112 : {
113 13 : const Socket &kSocket = entry.second;
114 :
115 13 : if (kSocket.isOwned)
116 : {
117 9 : if (unlink(kSocket.sockPath.c_str()) != 0)
118 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket @ '%s'", kSocket.sockPath.c_str());
119 9 : if (close(kSocket.sockFd) != 0)
120 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close listening socket");
121 :
122 9 : if (unlink(kSocket.lockPath.c_str()) != 0)
123 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket lock file @ '%s'", kSocket.lockPath.c_str());
124 9 : if (close(kSocket.lockFd) != 0)
125 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket lock file");
126 : }
127 : }
128 32 : }
129 :
130 : // -----------------------------------------------------------------------------
131 : /*!
132 : \static
133 : \internal
134 :
135 : Creates (if required) and takes the file lock associated with the socket
136 : path in the \a socket object.
137 :
138 : */
139 9 : bool ServerImpl::getSocketLock(Socket *socket)
140 : {
141 9 : std::string lockPath = socket->sockPath + ".lock";
142 9 : int fileDescriptor = open(lockPath.c_str(), O_CREAT | O_CLOEXEC, (S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP));
143 9 : if (fileDescriptor < 0)
144 : {
145 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to create / open lockfile @ '%s' (check permissions)", lockPath.c_str());
146 0 : return false;
147 : }
148 :
149 9 : if (flock(fileDescriptor, LOCK_EX | LOCK_NB) < 0)
150 : {
151 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to lock lockfile @ '%s', maybe another server is running",
152 : lockPath.c_str());
153 0 : close(fileDescriptor);
154 0 : return false;
155 : }
156 :
157 9 : struct stat sbuf = {0};
158 9 : if (stat(socket->sockPath.c_str(), &sbuf) < 0)
159 : {
160 9 : if (errno != ENOENT)
161 : {
162 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "did not manage to stat existing socket @ '%s'", socket->sockPath.c_str());
163 0 : close(fileDescriptor);
164 0 : return false;
165 : }
166 : }
167 0 : else if ((sbuf.st_mode & S_IWUSR) || (sbuf.st_mode & S_IWGRP))
168 : {
169 0 : unlink(socket->sockPath.c_str());
170 : }
171 :
172 9 : socket->lockFd = fileDescriptor;
173 9 : socket->lockPath = std::move(lockPath);
174 :
175 9 : return true;
176 : }
177 :
178 : // -----------------------------------------------------------------------------
179 : /*!
180 : \static
181 : \internal
182 :
183 : Closes the file descriptors and the deletes the files stored in the \a socket
184 : object.
185 :
186 : */
187 0 : void ServerImpl::closeListeningSocket(Socket *socket)
188 : {
189 0 : if (!socket->isOwned)
190 : {
191 0 : return;
192 : }
193 :
194 0 : if (!socket->sockPath.empty() && (unlink(socket->sockPath.c_str()) != 0) && (errno != ENOENT))
195 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket @ '%s'", socket->sockPath.c_str());
196 0 : if ((socket->sockFd >= 0) && (close(socket->sockFd) != 0))
197 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close listening socket");
198 :
199 0 : if (!socket->lockPath.empty() && (unlink(socket->lockPath.c_str()) != 0) && (errno != ENOENT))
200 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket lock file @ '%s'", socket->lockPath.c_str());
201 0 : if ((socket->lockFd >= 0) && (close(socket->lockFd) != 0))
202 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket lock file");
203 :
204 0 : socket->sockFd = -1;
205 0 : socket->sockPath.clear();
206 :
207 0 : socket->lockFd = -1;
208 0 : socket->lockPath.clear();
209 : }
210 :
211 9 : bool ServerImpl::addSocket(const std::string &socketPath,
212 : std::function<void(const std::shared_ptr<IClient> &)> clientConnectedCb,
213 : std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb)
214 : {
215 : // store the path
216 9 : Socket socket;
217 9 : socket.sockPath = socketPath;
218 :
219 : // create the socket
220 9 : socket.sockFd = ::socket(AF_UNIX, SOCK_SEQPACKET | SOCK_CLOEXEC | SOCK_NONBLOCK, 0);
221 9 : if (socket.sockFd == -1)
222 : {
223 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "socket error");
224 0 : return false;
225 : }
226 :
227 : // get the socket lock
228 9 : if (!getSocketLock(&socket))
229 : {
230 0 : closeListeningSocket(&socket);
231 0 : return false;
232 : }
233 :
234 : // bind to the given path
235 9 : struct sockaddr_un addr = {0};
236 9 : memset(&addr, 0x00, sizeof(addr));
237 9 : addr.sun_family = AF_UNIX;
238 9 : strncpy(addr.sun_path, socketPath.c_str(), sizeof(addr.sun_path) - 1);
239 :
240 9 : if (bind(socket.sockFd, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr)) == -1)
241 : {
242 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "bind error");
243 :
244 0 : closeListeningSocket(&socket);
245 0 : return false;
246 : }
247 :
248 : // put in listening mode
249 9 : if (listen(socket.sockFd, 1) == -1)
250 : {
251 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "listen error");
252 :
253 0 : closeListeningSocket(&socket);
254 0 : return false;
255 : }
256 :
257 : // create an id for the listening socket
258 9 : const uint64_t kSocketId = m_socketIdCounter++;
259 9 : if (kSocketId >= FIRST_CLIENT_ID)
260 : {
261 : // should never happen, we'd run out of file descriptors before
262 : // we hit the 10k limit on listening sockets
263 0 : RIALTO_IPC_LOG_ERROR("too many listening sockets");
264 :
265 0 : closeListeningSocket(&socket);
266 0 : return false;
267 : }
268 :
269 : // add the socket to epoll
270 9 : epoll_event event = {.events = EPOLLIN, .data = {.u64 = kSocketId}};
271 9 : if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socket.sockFd, &event) != 0)
272 : {
273 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add listening socket");
274 :
275 0 : closeListeningSocket(&socket);
276 0 : return false;
277 : }
278 :
279 : // store the client connected / disconnected callbacks
280 9 : socket.connectedCb = clientConnectedCb;
281 9 : socket.disconnectedCb = clientDisconnectedCb;
282 :
283 : // add to the internal map
284 9 : std::lock_guard<std::mutex> locker(m_socketsLock);
285 9 : m_sockets.emplace(kSocketId, std::move(socket));
286 :
287 9 : RIALTO_IPC_LOG_INFO("added listening socket '%s' to server", socketPath.c_str());
288 :
289 9 : return true;
290 : }
291 :
292 4 : bool ServerImpl::addSocket(int fd, std::function<void(const std::shared_ptr<IClient> &)> clientConnectedCb,
293 : std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb)
294 : {
295 : // store the path
296 4 : Socket socket;
297 4 : socket.isOwned = false;
298 4 : socket.sockFd = fd;
299 :
300 : // put in listening mode
301 4 : if (listen(socket.sockFd, 1) == -1)
302 : {
303 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "listen error");
304 0 : return false;
305 : }
306 :
307 : // create an id for the listening socket
308 4 : const uint64_t kSocketId = m_socketIdCounter++;
309 4 : if (kSocketId >= FIRST_CLIENT_ID)
310 : {
311 : // should never happen, we'd run out of file descriptors before
312 : // we hit the 10k limit on listening sockets
313 0 : RIALTO_IPC_LOG_ERROR("too many listening sockets");
314 0 : return false;
315 : }
316 :
317 : // add the socket to epoll
318 4 : epoll_event event = {.events = EPOLLIN, .data = {.u64 = kSocketId}};
319 4 : if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socket.sockFd, &event) != 0)
320 : {
321 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add listening socket");
322 0 : return false;
323 : }
324 :
325 : // store the client connected / disconnected callbacks
326 4 : socket.connectedCb = clientConnectedCb;
327 4 : socket.disconnectedCb = clientDisconnectedCb;
328 :
329 : // add to the internal map
330 4 : std::lock_guard<std::mutex> locker(m_socketsLock);
331 4 : m_sockets.emplace(kSocketId, std::move(socket));
332 :
333 4 : RIALTO_IPC_LOG_INFO("added listening socket with fd: %d to server", fd);
334 :
335 4 : return true;
336 : }
337 :
338 16 : std::shared_ptr<IClient> ServerImpl::addClient(int socketFd,
339 : std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb)
340 : {
341 : // sanity check the supplied socket is of the right type
342 : struct sockaddr addr;
343 16 : socklen_t len = sizeof(sockaddr);
344 16 : if ((getsockname(socketFd, &addr, &len) < 0) || (len < sizeof(sa_family_t)))
345 : {
346 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get name of supplied socket");
347 0 : return nullptr;
348 : }
349 16 : if (addr.sa_family != AF_UNIX)
350 : {
351 0 : RIALTO_IPC_LOG_ERROR("supplied client socket is not a unix domain socket");
352 0 : return nullptr;
353 : }
354 :
355 16 : int type = 0;
356 16 : len = sizeof(type);
357 16 : if ((getsockopt(socketFd, SOL_SOCKET, SO_TYPE, &type, &len) < 0) || (len != sizeof(type)))
358 : {
359 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
360 0 : return nullptr;
361 : }
362 16 : if (type != SOCK_SEQPACKET)
363 : {
364 0 : RIALTO_IPC_LOG_ERROR("supplied client socket is not of type SOCK_SEQPACKET");
365 0 : return nullptr;
366 : }
367 :
368 : // dup the socket and set the O_CLOEXEC bit
369 16 : int duppedFd = fcntl(socketFd, F_DUPFD_CLOEXEC, 3);
370 16 : if (duppedFd < 0)
371 : {
372 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
373 0 : return nullptr;
374 : }
375 :
376 : // ensure the SOCK_NONBLOCK flag is set
377 16 : int flags = fcntl(duppedFd, F_GETFL);
378 16 : if ((flags < 0) || (fcntl(duppedFd, F_SETFL, flags | O_NONBLOCK) < 0))
379 : {
380 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to set socket to non-blocking mode");
381 0 : close(duppedFd);
382 0 : return nullptr;
383 : }
384 :
385 : // finally, add the socket to the list of clients
386 48 : auto client = addClientSocket(duppedFd, "", std::move(clientDisconnectedCb));
387 16 : if (!client)
388 : {
389 0 : close(duppedFd);
390 : }
391 :
392 16 : return client;
393 : }
394 :
395 0 : int ServerImpl::fd() const
396 : {
397 0 : return m_pollFd;
398 : }
399 :
400 : // -----------------------------------------------------------------------------
401 : /*!
402 : \internal
403 :
404 : Writes to the eventfd to wake the event loop. Typically called when requesting
405 : it to shutdown or external code has requested that a client be disconnected.
406 :
407 : */
408 29 : void ServerImpl::wakeEventLoop() const
409 : {
410 29 : if (m_wakeEventFd < 0)
411 : {
412 0 : RIALTO_IPC_LOG_ERROR("invalid wake event fd");
413 : }
414 : else
415 : {
416 29 : uint64_t value = 1;
417 29 : if (TEMP_FAILURE_RETRY(::write(m_wakeEventFd, &value, sizeof(value))) != sizeof(value))
418 : {
419 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to write to the event fd");
420 : }
421 : }
422 29 : }
423 :
424 2893 : bool ServerImpl::wait(int timeoutMSecs)
425 : {
426 2893 : if (m_pollFd < 0)
427 : {
428 0 : return false;
429 : }
430 :
431 : // wait for any event (with timeout)
432 : struct pollfd fds[2];
433 2893 : fds[0].fd = m_pollFd;
434 2893 : fds[0].events = POLLIN;
435 :
436 2893 : int rc = TEMP_FAILURE_RETRY(poll(fds, 1, timeoutMSecs));
437 2893 : if (rc < 0)
438 : {
439 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "poll failed?");
440 0 : return false;
441 : }
442 :
443 2893 : return true;
444 : }
445 :
/**
 * @brief Dispatches all pending events, then tears down condemned clients.
 *
 * Reads up to 32 events from epoll without blocking and routes each one by its
 * data.u64 id:
 *  - WAKE_EVENT_ID: drain the wake eventfd;
 *  - id < FIRST_CLIENT_ID: a listening socket, so accept the new connection;
 *  - anything else: a client socket, so read and handle its messages.
 *
 * Afterwards every client in m_condemnedClients is processed:
 *  - it is removed from m_clients;
 *  - its socket is removed from epoll, shut down and closed;
 *  - its disconnected callback is invoked.
 * m_clientsLock is dropped while each client is torn down, so the callback
 * and the client object's destruction run without the lock held.
 *
 * @return false if there is no epoll fd or epoll_wait fails; true otherwise.
 */
bool ServerImpl::process()
{
    if (m_pollFd < 0)
    {
        RIALTO_IPC_LOG_ERROR("missing epoll");
        return false;
    }

    // read up to 32 events
    const int kMaxEvents = 32;
    struct epoll_event events[kMaxEvents];

    // timeout of 0: never blocks here (wait() can be used to block beforehand)
    int rc = TEMP_FAILURE_RETRY(epoll_wait(m_pollFd, events, kMaxEvents, 0));
    if (rc < 0)
    {
        RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_wait failed");
        return false;
    }

    // process the events (maybe 0 if nothing was pending)
    for (int i = 0; i < rc; i++)
    {
        const struct epoll_event &kEvent = events[i];

        // check if a wake event, in which case just clear the eventfd
        if (kEvent.data.u64 == WAKE_EVENT_ID)
        {
            uint64_t ignore;
            if (TEMP_FAILURE_RETRY(::read(m_wakeEventFd, &ignore, sizeof(ignore))) != sizeof(ignore))
            {
                RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to read wake eventfd");
            }
        }

        // check for events on the listening socket
        else if (kEvent.data.u64 < FIRST_CLIENT_ID)
        {
            if (kEvent.events & EPOLLIN)
                processNewConnection(kEvent.data.u64);
            if (kEvent.events & EPOLLERR)
                RIALTO_IPC_LOG_ERROR("error occurred on listening socket");
        }

        // otherwise, the event must have come from a socket
        else
        {
            processClientSocket(kEvent.data.u64, kEvent.events);
        }
    }

    // if we have client sockets that are condemned then we need to shut down
    // and close them as well as remove from epoll
    std::unique_lock<std::mutex> locker(m_clientsLock);
    if (!m_condemnedClients.empty())
    {
        // take a copy of the set so we can process without the lock held
        std::set<uint64_t> theCondemned;
        m_condemnedClients.swap(theCondemned);

        for (uint64_t clientId : theCondemned)
        {
            auto it = m_clients.find(clientId);
            if (it == m_clients.end())
            {
                RIALTO_IPC_LOG_ERROR("failed to find condemned client");
                continue;
            }

            // copy the details out before erasing the map entry
            ClientDetails details = it->second;

            // remove from the list of clients
            m_clients.erase(it);

            // drop the lock while closing the connection and removing from epoll
            locker.unlock();

            // remove the socket from epoll and close it
            if (details.sock >= 0)
            {
                if (epoll_ctl(m_pollFd, EPOLL_CTL_DEL, details.sock, nullptr) != 0)
                    RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to remove socket from epoll");

                if (shutdown(details.sock, SHUT_RDWR) != 0)
                    RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to shutdown socket");
                if (close(details.sock) != 0)
                    RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close socket");
            }

            // let the installed handler know a client has disconnected
            if (details.disconnectedCb)
                details.disconnectedCb(details.client);

            // ensure client object is destructed without the client lock held
            details.client.reset();

            // re-take the lock for the next client
            locker.lock();
        }
    }

    return true;
}
548 :
549 : // -----------------------------------------------------------------------------
550 : /*!
551 : \internal
552 :
553 : Adds the \a socketFd to the internal list of client sockets. This is called
554 : when a new connection is accepted on a listening socket, or when a client
555 : fd is added via ServerImpl::addClient(...).
556 :
557 : Returns a nullptr if failed to add the socket.
558 :
559 : */
560 29 : std::shared_ptr<ClientImpl> ServerImpl::addClientSocket(int socketFd, const std::string &listeningSocketPath,
561 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb)
562 : {
563 : // get the client credentials
564 29 : struct ucred clientCreds = {0};
565 29 : socklen_t clientCredsLen = sizeof(clientCreds);
566 :
567 58 : if ((getsockopt(socketFd, SOL_SOCKET, SO_PEERCRED, &clientCreds, &clientCredsLen) < 0) ||
568 29 : (clientCredsLen != sizeof(clientCreds)))
569 : {
570 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get client's details");
571 0 : return nullptr;
572 : }
573 :
574 : // create a new unique client id for the connection
575 29 : const uint64_t kClientId = m_clientIdCounter++;
576 :
577 : // add the new socket to the poll loop
578 29 : epoll_event event = {.events = EPOLLIN, .data = {.u64 = kClientId}};
579 29 : if (epoll_ctl(m_pollFd, EPOLL_CTL_ADD, socketFd, &event) != 0)
580 : {
581 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add client socket");
582 0 : return nullptr;
583 : }
584 :
585 : // create initial client object for the socket
586 29 : auto client = std::make_shared<ClientImpl>(shared_from_this(), kClientId, clientCreds);
587 :
588 : // and the details for the internal list
589 29 : ClientDetails clientDetails;
590 29 : clientDetails.sock = socketFd;
591 29 : clientDetails.disconnectedCb = std::move(disconnectedCb);
592 29 : clientDetails.client = client;
593 :
594 : // add to the set of clients
595 : {
596 29 : std::lock_guard<std::mutex> locker(m_clientsLock);
597 29 : m_clients.emplace(kClientId, clientDetails);
598 : }
599 :
600 29 : RIALTO_IPC_LOG_INFO("new client connected - giving id %" PRIu64, kClientId);
601 :
602 29 : return client;
603 : }
604 :
605 : // -----------------------------------------------------------------------------
606 : /*!
607 : \internal
608 :
609 : Called when an event occurs on the listening socket. The code first accepts
610 : the connection and retrieves the client details, it then calls the installed
611 : handler to determine if we should drop this connection or not.
612 :
613 :
614 : */
615 13 : void ServerImpl::processNewConnection(uint64_t socketId)
616 : {
617 13 : RIALTO_IPC_LOG_DEBUG("processing new connection");
618 :
619 13 : std::unique_lock<std::mutex> socketLocker(m_socketsLock);
620 :
621 : // find matching socket object
622 13 : auto it = m_sockets.find(socketId);
623 13 : if (it == m_sockets.end())
624 : {
625 0 : RIALTO_IPC_LOG_ERROR("failed to find listening socket with id %" PRIu64, socketId);
626 0 : return;
627 : }
628 :
629 13 : const Socket &kSocket = it->second;
630 :
631 : // accept the connection from the client
632 13 : struct sockaddr clientAddr = {0};
633 13 : socklen_t clientAddrLen = sizeof(clientAddr);
634 :
635 13 : int clientSock = accept4(kSocket.sockFd, &clientAddr, &clientAddrLen, SOCK_NONBLOCK | SOCK_CLOEXEC);
636 13 : if (clientSock < 0)
637 : {
638 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to accept client connection");
639 0 : return;
640 : }
641 :
642 13 : const std::string kSockPath = kSocket.sockPath;
643 13 : std::function<void(const std::shared_ptr<IClient> &)> connectedCb = kSocket.connectedCb;
644 13 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb = kSocket.disconnectedCb;
645 :
646 13 : socketLocker.unlock();
647 :
648 : // attempt to add the socket to the client list
649 13 : auto client = addClientSocket(clientSock, kSockPath, std::move(disconnectedCb));
650 13 : if (!client)
651 : {
652 0 : close(clientSock);
653 0 : return;
654 : }
655 :
656 : // notify the handler that a new connection has been made
657 13 : if (connectedCb)
658 : {
659 : // tell the handler we have a new client
660 13 : connectedCb(client);
661 : }
662 : }
663 :
664 : // -----------------------------------------------------------------------------
665 : /*!
666 : \internal
667 : \static
668 :
669 : Reads all the file descriptors from a message. It returns the file descriptors
670 : as a vector of FileDescriptor objects, these objects safely store the fd and
671 : close them when they're destructed.
672 :
673 : */
674 0 : static std::vector<FileDescriptor> readMessageFds(const struct msghdr *msg, size_t limit)
675 : {
676 0 : std::vector<FileDescriptor> fds;
677 :
678 0 : for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg); cmsg != nullptr;
679 0 : cmsg = CMSG_NXTHDR(const_cast<struct msghdr *>(msg), cmsg))
680 : {
681 0 : if ((cmsg->cmsg_level == SOL_SOCKET) && (cmsg->cmsg_type == SCM_RIGHTS))
682 : {
683 0 : const unsigned kFdsLength = cmsg->cmsg_len - CMSG_LEN(0);
684 0 : if ((kFdsLength < sizeof(int)) || ((kFdsLength % sizeof(int)) != 0))
685 : {
686 0 : RIALTO_IPC_LOG_ERROR("invalid fd array size");
687 : }
688 : else
689 : {
690 0 : const size_t n = kFdsLength / sizeof(int);
691 0 : RIALTO_IPC_LOG_DEBUG("received %zu fds", n);
692 :
693 0 : fds.reserve(std::min(limit, n));
694 :
695 0 : const int *kFds = reinterpret_cast<int *>(CMSG_DATA(cmsg));
696 0 : for (size_t i = 0; i < n; i++)
697 : {
698 0 : RIALTO_IPC_LOG_DEBUG("received fd %d", kFds[i]);
699 :
700 0 : if (fds.size() >= limit)
701 : {
702 0 : RIALTO_IPC_LOG_ERROR("received to many file descriptors, "
703 : "exceeding max per message, closing left overs");
704 : }
705 : else
706 : {
707 0 : FileDescriptor fd(kFds[i]);
708 0 : if (!fd.isValid())
709 : {
710 0 : RIALTO_IPC_LOG_ERROR("received invalid fd (couldn't dup)");
711 : }
712 : else
713 : {
714 0 : fds.emplace_back(std::move(fd));
715 : }
716 : }
717 :
718 0 : if (close(kFds[i]) != 0)
719 : {
720 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close received fd");
721 : }
722 : }
723 : }
724 : }
725 : }
726 :
727 0 : return fds;
728 : }
729 :
730 : // -----------------------------------------------------------------------------
731 : /*!
732 : \internal
733 :
734 : Processes an event from a client socket.
735 :
736 : */
737 49 : void ServerImpl::processClientSocket(uint64_t clientId, unsigned events)
738 : {
739 : // take the lock while accessing the client list
740 49 : std::unique_lock<std::mutex> locker(m_clientsLock);
741 :
742 49 : auto it = m_clients.find(clientId);
743 49 : if (it == m_clients.end())
744 : {
745 : // should never happen
746 0 : RIALTO_IPC_LOG_ERROR("received an event from a socket with no matching client");
747 0 : return;
748 : }
749 :
750 : // check if the client is marked for closure, if so then just ignore the data
751 49 : if (m_condemnedClients.count(clientId) != 0)
752 : {
753 0 : return;
754 : }
755 :
756 : // get the socket that corresponds to the client connection
757 49 : const int kSockFd = it->second.sock;
758 :
759 : // get the client object
760 49 : std::shared_ptr<ClientImpl> clientObj = it->second.client;
761 :
762 : // can safely release the lock now we have the clientId and client object
763 49 : locker.unlock();
764 :
765 : // if there was an error disconnect the socket
766 49 : if (events & EPOLLERR)
767 : {
768 0 : RIALTO_IPC_LOG_ERROR("error detected on client socket - disconnecting client");
769 0 : disconnectClient(clientId);
770 0 : return;
771 : }
772 :
773 49 : if (events & EPOLLIN)
774 : {
775 : // read all messages from the client socket, we break out if the socket is closed
776 : // or EWOULDBLOCK is returned on a read (ie. no more messages to read)
777 : while (true)
778 : {
779 70 : struct msghdr msg = {nullptr};
780 70 : struct iovec io = {.iov_base = m_recvDataBuf, .iov_len = sizeof(m_recvDataBuf)};
781 :
782 70 : bzero(&msg, sizeof(msg));
783 70 : msg.msg_iov = &io;
784 70 : msg.msg_iovlen = 1;
785 70 : msg.msg_control = m_recvCtrlBuf;
786 70 : msg.msg_controllen = sizeof(m_recvCtrlBuf);
787 :
788 : // read one message
789 70 : ssize_t rd = TEMP_FAILURE_RETRY(recvmsg(kSockFd, &msg, MSG_CMSG_CLOEXEC));
790 70 : if (rd < 0)
791 : {
792 21 : if (errno != EWOULDBLOCK)
793 : {
794 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "error reading client socket");
795 0 : disconnectClient(clientId);
796 : }
797 :
798 49 : break;
799 : }
800 49 : else if (rd == 0)
801 : {
802 : // client closed connection, and we've read all data, add to the condemned set
803 : // so is cleaned up once all the events are processed
804 28 : disconnectClient(clientId);
805 :
806 28 : break;
807 : }
808 21 : else if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))
809 : {
810 0 : RIALTO_IPC_LOG_WARN("received message from client %" PRIu64 " truncated, discarding", clientId);
811 :
812 : // make sure to close all the fds, otherwise we'll leak them
813 0 : readMessageFds(&msg, 16);
814 : }
815 : else
816 : {
817 : // if there is control data then assume fd(s) have been passed
818 21 : if (msg.msg_controllen > 0)
819 : {
820 0 : processClientMessage(clientObj, m_recvDataBuf, rd, readMessageFds(&msg, 16));
821 : }
822 : else
823 : {
824 21 : processClientMessage(clientObj, m_recvDataBuf, rd);
825 : }
826 : }
827 : }
828 : }
829 49 : }
830 :
831 : // -----------------------------------------------------------------------------
832 : /*!
833 : \internal
834 :
835 : Places the received file descriptors into the request message.
836 :
837 : It works by iterating over the fields in the message, finding ones that are
838 : marked as 'field_is_fd' and then replacing the received integer value with
839 : an actual file descriptor.
840 :
841 : */
842 21 : static bool addRequestFileDescriptors(google::protobuf::Message *request, const std::vector<FileDescriptor> &requestFds)
843 : {
844 21 : auto fdIterator = requestFds.begin();
845 :
846 21 : const google::protobuf::Descriptor *kDescriptor = request->GetDescriptor();
847 21 : const google::protobuf::Reflection *kReflection = nullptr;
848 :
849 21 : const int n = kDescriptor->field_count();
850 75 : for (int i = 0; i < n; i++)
851 : {
852 54 : auto fieldDescriptor = kDescriptor->field(i);
853 54 : if (fieldDescriptor->options().HasExtension(field_is_fd) && fieldDescriptor->options().GetExtension(field_is_fd))
854 : {
855 0 : if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
856 : {
857 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
858 0 : return false;
859 : }
860 :
861 0 : if (!kReflection)
862 : {
863 0 : kReflection = request->GetReflection();
864 : }
865 :
866 0 : if (kReflection->HasField(*request, fieldDescriptor))
867 : {
868 0 : if (fdIterator == requestFds.end())
869 : {
870 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but one was supplied");
871 0 : return false;
872 : }
873 :
874 0 : kReflection->SetInt32(request, fieldDescriptor, fdIterator->fd());
875 0 : ++fdIterator;
876 : }
877 : }
878 : }
879 :
880 21 : if (fdIterator != requestFds.end())
881 : {
882 0 : RIALTO_IPC_LOG_ERROR("received too many file descriptors in the message");
883 0 : return false;
884 : }
885 :
886 21 : return true;
887 : }
888 :
889 : // -----------------------------------------------------------------------------
890 : /*!
891 : \internal
892 :
893 : Processes a message received on a client socket.
894 :
895 : */
896 21 : void ServerImpl::processClientMessage(const std::shared_ptr<ClientImpl> &client, const uint8_t *data, size_t dataLen,
897 : const std::vector<FileDescriptor> &fds)
898 : {
899 21 : RIALTO_IPC_LOG_DEBUG("processing client message of size %zu bytes (%zu fds) from client %" PRId64, dataLen,
900 : fds.size(), client->id());
901 :
902 : // parse the message
903 21 : transport::MessageToServer message;
904 21 : if (!message.ParseFromArray(data, static_cast<int>(dataLen)))
905 : {
906 0 : RIALTO_IPC_LOG_ERROR("invalid request");
907 0 : return;
908 : }
909 :
910 21 : if (message.has_call())
911 : {
912 21 : processMethodCall(client, message.call(), fds);
913 : }
914 : else
915 : {
916 0 : RIALTO_IPC_LOG_WARN("received unknown message type from client");
917 : }
918 21 : }
919 :
920 : // -----------------------------------------------------------------------------
921 : /*!
922 : \internal
923 :
924 : Processes a method call requst from a client.
925 :
926 : */
927 21 : void ServerImpl::processMethodCall(const std::shared_ptr<ClientImpl> &client, const transport::MethodCall &call,
928 : const std::vector<FileDescriptor> &fds)
929 : {
930 : // try and find the service with the given name
931 21 : const std::string &kServiceName = call.service_name();
932 21 : auto it = client->m_services.find(kServiceName);
933 21 : if (it == client->m_services.end())
934 : {
935 0 : RIALTO_IPC_LOG_ERROR("unknown service request '%s'", kServiceName.c_str());
936 :
937 0 : sendErrorReply(client, call.serial_id(), "Unknown service '%s'", kServiceName.c_str());
938 0 : return;
939 : }
940 :
941 21 : std::shared_ptr<google::protobuf::Service> service = it->second;
942 :
943 : // try and find the method
944 21 : const std::string &kMethodName = call.method_name();
945 21 : const google::protobuf::MethodDescriptor *kMethod = service->GetDescriptor()->FindMethodByName(kMethodName);
946 21 : if (!kMethod)
947 : {
948 0 : RIALTO_IPC_LOG_ERROR("no method with name '%s'", kMethodName.c_str());
949 :
950 0 : sendErrorReply(client, call.serial_id(), "Unknown method '%s'", kMethodName.c_str());
951 0 : return;
952 : }
953 :
954 : // check if the method is expecting a reply
955 21 : const bool kNoReply = kMethod->options().HasExtension(no_reply) && kMethod->options().GetExtension(no_reply);
956 :
957 : // parse the request data
958 21 : google::protobuf::Message *requestMessage = service->GetRequestPrototype(kMethod).New();
959 21 : if (!requestMessage->ParseFromString(call.request_message()))
960 : {
961 0 : RIALTO_IPC_LOG_ERROR("failed to parse method from array");
962 : }
963 21 : else if (!addRequestFileDescriptors(requestMessage, fds))
964 : {
965 0 : RIALTO_IPC_LOG_ERROR("mismatch of file descriptors to the request");
966 : }
967 : else
968 : {
969 21 : RIALTO_IPC_LOG_DEBUG("call{ serial %" PRIu64 " } - %s.%s { %s }", call.serial_id(), kServiceName.c_str(),
970 : kMethodName.c_str(), requestMessage->ShortDebugString().c_str());
971 :
972 21 : auto *controller = new ServerControllerImpl(client, call.serial_id());
973 :
974 21 : if (kNoReply)
975 : {
976 : // we should not send a reply for this call, so call the code to handle the
977 : // request, but no need to pass a controller, response or closure object
978 1 : static google::protobuf::internal::FunctionClosure0 nullClosure(&google::protobuf::DoNothing, false);
979 1 : service->CallMethod(kMethod, controller, requestMessage, nullptr, &nullClosure);
980 :
981 1 : delete controller;
982 : }
983 : else
984 : {
985 : // create a response
986 20 : google::protobuf::Message *responseMessage = service->GetResponsePrototype(kMethod).New();
987 :
988 : // this is finally where we call the service implementation to process the request
989 20 : service->CallMethod(kMethod, controller, requestMessage, responseMessage,
990 : google::protobuf::NewCallback(this, &ServerImpl::handleResponse, controller,
991 : responseMessage));
992 : }
993 : }
994 :
995 21 : delete requestMessage;
996 : }
997 :
998 : // -----------------------------------------------------------------------------
999 : /*!
1000 : \internal
1001 : \threadsafe
1002 :
1003 : Sends a message to the given client if still connected.
1004 :
1005 : */
1006 20 : void ServerImpl::sendReply(uint64_t clientId, const std::shared_ptr<msghdr> &msg)
1007 : {
1008 : // now take the lock (so the socket is not closed beneath us) and send the reply
1009 20 : std::lock_guard<std::mutex> locker(m_clientsLock);
1010 :
1011 20 : auto it = m_clients.find(clientId);
1012 20 : if (it == m_clients.end())
1013 : {
1014 1 : RIALTO_IPC_LOG_WARN("socket removed before error reply could be sent");
1015 : }
1016 19 : else if (it->second.sock < 0)
1017 : {
1018 0 : RIALTO_IPC_LOG_WARN("socket closed before error reply could be sent");
1019 : }
1020 19 : else if (!msg)
1021 : {
1022 0 : RIALTO_IPC_LOG_WARN("invalid msg to send on socket, ignoring");
1023 : }
1024 19 : else if (TEMP_FAILURE_RETRY(sendmsg(it->second.sock, msg.get(), MSG_NOSIGNAL)) !=
1025 19 : static_cast<ssize_t>(msg->msg_iov->iov_len))
1026 : {
1027 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to send the complete error reply message");
1028 : }
1029 20 : }
1030 :
1031 : // -----------------------------------------------------------------------------
1032 : /*!
1033 : \internal
1034 :
1035 : Sends an error back to the client with a reason string in \a format.
1036 :
1037 : */
1038 0 : void ServerImpl::sendErrorReply(const std::shared_ptr<ClientImpl> &client, uint64_t serialId, const char *format, ...)
1039 : {
1040 : // construct the error string
1041 : va_list ap;
1042 0 : va_start(ap, format);
1043 : char reason[512];
1044 0 : vsnprintf(reason, sizeof(reason), format, ap);
1045 0 : va_end(ap);
1046 :
1047 : // construct the reply message
1048 0 : auto msg = populateErrorReply(client, serialId, reason);
1049 :
1050 : // and send it
1051 0 : sendReply(client->id(), msg);
1052 : }
1053 :
1054 : // -----------------------------------------------------------------------------
1055 : /*!
1056 : \internal
1057 : \threadsafe
1058 :
1059 : Called once the service handler code has processed the request and
1060 : completed it. This may be called from a different thread if the service
1061 : implementation decided to off-load the request to a processing thread.
1062 :
1063 : */
1064 20 : void ServerImpl::handleResponse(ServerControllerImpl *controller, google::protobuf::Message *response)
1065 : {
1066 20 : if (!controller || !controller->m_kClient)
1067 : {
1068 0 : RIALTO_IPC_LOG_ERROR("missing controller or attached client");
1069 0 : return;
1070 : }
1071 :
1072 20 : const std::shared_ptr<const ClientImpl> kClient = controller->m_kClient;
1073 20 : const uint64_t kClientId = kClient->id();
1074 :
1075 20 : std::shared_ptr<msghdr> message;
1076 20 : if (!controller->m_failed)
1077 : {
1078 13 : message = populateReply(kClient, controller->m_kSerialId, response);
1079 : }
1080 : else
1081 : {
1082 7 : message = populateErrorReply(kClient, controller->m_kSerialId, controller->m_failureReason);
1083 : }
1084 :
1085 : // no longer need the controller or the response objects
1086 20 : delete response;
1087 20 : delete controller;
1088 :
1089 : // send the reply message to the given client
1090 20 : sendReply(kClientId, message);
1091 : }
1092 :
1093 : // -----------------------------------------------------------------------------
1094 : /*!
1095 : \internal
1096 :
1097 : Gets all the file descriptors that are stored in the \a response message
1098 : and returns them as a vector of ints.
1099 :
1100 : It works by iterating over the fields in the message, finding ones that are
1101 : marked as 'field_is_fd' and inserting the fd value into control part of the
1102 : message. The actual integer sent in the data part of the message is set to
1103 : -1.
1104 :
1105 : */
1106 17 : static std::vector<int> getResponseFileDescriptors(google::protobuf::Message *response)
1107 : {
1108 17 : std::vector<int> fds;
1109 :
1110 : // process any file descriptors from the response message
1111 17 : const google::protobuf::Descriptor *kDescriptor = response->GetDescriptor();
1112 17 : const google::protobuf::Reflection *kReflection = nullptr;
1113 :
1114 17 : const int n = kDescriptor->field_count();
1115 34 : for (int i = 0; i < n; i++)
1116 : {
1117 17 : auto fieldDescriptor = kDescriptor->field(i);
1118 17 : if (fieldDescriptor->options().HasExtension(field_is_fd) && fieldDescriptor->options().GetExtension(field_is_fd))
1119 : {
1120 0 : if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
1121 : {
1122 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
1123 0 : return {};
1124 : }
1125 :
1126 0 : if (!kReflection)
1127 : {
1128 0 : kReflection = response->GetReflection();
1129 : }
1130 :
1131 0 : if (kReflection->HasField(*response, fieldDescriptor))
1132 : {
1133 0 : fds.push_back(kReflection->GetInt32(*response, fieldDescriptor));
1134 0 : kReflection->SetInt32(response, fieldDescriptor, -1);
1135 : }
1136 : }
1137 : }
1138 :
1139 17 : return fds;
1140 : }
1141 :
1142 : // -----------------------------------------------------------------------------
1143 : /*!
1144 : \internal
1145 : \static
1146 :
1147 : Populates tha socket message buffer with the reply data for the RPC request.
1148 :
1149 : */
1150 13 : std::shared_ptr<msghdr> ServerImpl::populateReply(const std::shared_ptr<const ClientImpl> &client, uint64_t serialId,
1151 : google::protobuf::Message *response)
1152 : {
1153 : // create the base reply
1154 13 : transport::MessageFromServer message;
1155 13 : transport::MethodCallReply *reply = message.mutable_reply();
1156 13 : if (!reply)
1157 : {
1158 0 : RIALTO_IPC_LOG_ERROR("failed to create mutable reply object");
1159 0 : return nullptr;
1160 : }
1161 :
1162 : // convert the message response to a data string
1163 13 : std::string respString = response->SerializeAsString();
1164 :
1165 : // wrap in a transport response and send that
1166 13 : reply->set_reply_id(serialId);
1167 13 : reply->set_reply_message(std::move(respString));
1168 :
1169 : // next need to check if the response message has any file descriptors in
1170 : // it that need to be attached
1171 13 : const std::vector<int> kFds = getResponseFileDescriptors(response);
1172 13 : const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
1173 :
1174 : // calculate the size of the reply
1175 13 : const size_t kRequiredDataLen = message.ByteSizeLong();
1176 13 : if (kRequiredDataLen > m_kMaxMessageLen)
1177 : {
1178 0 : RIALTO_IPC_LOG_ERROR("reply exceeds maximum message limit (%zu, max %zu)", kRequiredDataLen, m_kMaxMessageLen);
1179 :
1180 : // error message is too big, replace with a generic error
1181 0 : return populateErrorReply(client, serialId, "Internal error - reply message to large");
1182 : }
1183 :
1184 : // build the socket message to send
1185 : auto msgBuf =
1186 13 : m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + kRequiredDataLen);
1187 :
1188 13 : auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
1189 13 : bzero(header, sizeof(msghdr));
1190 :
1191 13 : auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
1192 13 : header->msg_control = ctrl;
1193 13 : header->msg_controllen = kRequiredCtrlLen;
1194 :
1195 13 : auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
1196 13 : header->msg_iov = iov;
1197 13 : header->msg_iovlen = 1;
1198 :
1199 13 : auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
1200 13 : iov->iov_base = data;
1201 13 : iov->iov_len = kRequiredDataLen;
1202 :
1203 : // copy in the data
1204 13 : message.SerializeWithCachedSizesToArray(data);
1205 :
1206 : // add the fds
1207 13 : if (!kFds.empty())
1208 : {
1209 0 : struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
1210 0 : if (!cmsg)
1211 : {
1212 0 : RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
1213 0 : return nullptr;
1214 : }
1215 :
1216 0 : cmsg->cmsg_level = SOL_SOCKET;
1217 0 : cmsg->cmsg_type = SCM_RIGHTS;
1218 0 : cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
1219 0 : memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
1220 0 : header->msg_controllen = cmsg->cmsg_len;
1221 : }
1222 :
1223 13 : RIALTO_IPC_LOG_DEBUG("reply{ serial %" PRIu64 " } - { %s }", serialId, response->ShortDebugString().c_str());
1224 :
1225 : // std::reinterpret_pointer_cast is only implemented in C++17 and newer, so for
1226 : // now do it manually
1227 13 : return std::shared_ptr<msghdr>(msgBuf, reinterpret_cast<msghdr *>(msgBuf.get()));
1228 : }
1229 :
1230 : // -----------------------------------------------------------------------------
1231 : /*!
1232 : \internal
1233 :
1234 : Populates tha socket message buffer with a reply error message.
1235 :
1236 : */
1237 7 : std::shared_ptr<msghdr> ServerImpl::populateErrorReply(const std::shared_ptr<const ClientImpl> &client,
1238 : uint64_t serialId, const std::string &reason)
1239 : {
1240 : // create the base reply
1241 7 : transport::MessageFromServer message;
1242 7 : transport::MethodCallError *error = message.mutable_error();
1243 7 : error->set_reply_id(serialId);
1244 : error->set_error_reason(reason);
1245 :
1246 : // check the message will fit
1247 7 : size_t replySize = message.ByteSizeLong();
1248 7 : if (replySize > m_kMaxMessageLen)
1249 : {
1250 0 : RIALTO_IPC_LOG_ERROR("error reply exceeds max message size");
1251 :
1252 : // error message is to big, replace with a generic error
1253 : error->set_error_reason("Error message truncated");
1254 0 : replySize = message.ByteSizeLong();
1255 : }
1256 :
1257 7 : RIALTO_IPC_LOG_DEBUG("error{ serial %" PRIu64 " } - \"%s\"", serialId, reason.c_str());
1258 :
1259 : // construct the message to send on the socket
1260 7 : auto msgBuf = m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + replySize);
1261 :
1262 7 : auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
1263 7 : bzero(header, sizeof(msghdr));
1264 :
1265 7 : auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr));
1266 7 : header->msg_iov = iov;
1267 7 : header->msg_iovlen = 1;
1268 :
1269 7 : auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + sizeof(iovec));
1270 7 : iov->iov_base = data;
1271 7 : iov->iov_len = replySize;
1272 :
1273 : // serialise the reply to the message buffer
1274 7 : message.SerializeWithCachedSizesToArray(data);
1275 :
1276 : // std::reinterpret_pointer_cast is only implemented in C++17 and newer, so for
1277 : // now do it manually
1278 14 : return std::shared_ptr<msghdr>(msgBuf, reinterpret_cast<msghdr *>(msgBuf.get()));
1279 7 : }
1280 :
1281 : // -----------------------------------------------------------------------------
1282 : /*!
1283 : \threadsafe
1284 :
1285 : Returns \c true if the client with \a clientId is currently connected to
1286 : the server.
1287 :
1288 : */
1289 61 : bool ServerImpl::isClientConnected(uint64_t clientId) const
1290 : {
1291 61 : std::unique_lock<std::mutex> locker(m_clientsLock);
1292 122 : return (m_clients.count(clientId) > 0);
1293 61 : }
1294 :
1295 : // -----------------------------------------------------------------------------
1296 : /*!
1297 : \threadsafe
1298 :
1299 : May be called internally when there is an error on the socket, or externally
1300 : (possibly from a different thread) if the handler or service code decides to
1301 : close the client connection.
1302 :
1303 : */
1304 29 : void ServerImpl::disconnectClient(uint64_t clientId)
1305 : {
1306 29 : std::unique_lock<std::mutex> locker(m_clientsLock);
1307 29 : m_condemnedClients.insert(clientId);
1308 29 : locker.unlock();
1309 :
1310 29 : wakeEventLoop();
1311 : }
1312 :
1313 : // -----------------------------------------------------------------------------
1314 : /*!
1315 : \threadsafe
1316 :
1317 : Called via the IAVBusClient interface when an async event should be sent.
1318 :
1319 : This may be called from any thread, or from within the rpc message handler.
1320 :
1321 : The \a clientId is the client to send the event to.
1322 :
1323 : */
1324 4 : bool ServerImpl::sendEvent(uint64_t clientId, const std::shared_ptr<google::protobuf::Message> &eventMessage)
1325 : {
1326 : // gets the file descriptors from the event message
1327 4 : const std::vector<int> kFds = getResponseFileDescriptors(eventMessage.get());
1328 4 : const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
1329 :
1330 : // create the base reply
1331 4 : transport::MessageFromServer message;
1332 4 : transport::EventFromServer *event = message.mutable_event();
1333 4 : if (!event)
1334 : {
1335 0 : RIALTO_IPC_LOG_ERROR("failed to create mutable event object");
1336 0 : return false;
1337 : }
1338 :
1339 8 : event->set_event_name(eventMessage->GetTypeName());
1340 :
1341 : // convert the event to a data string
1342 4 : std::string respString = eventMessage->SerializeAsString();
1343 :
1344 : // wrap in a transport response and send that
1345 4 : event->set_message(std::move(respString));
1346 :
1347 : // check the reply will fit
1348 4 : size_t requiredDataLen = message.ByteSizeLong();
1349 4 : if (requiredDataLen > m_kMaxMessageLen)
1350 : {
1351 0 : RIALTO_IPC_LOG_ERROR("event message to big to fit in buffer (size %zu, max size %zu)", requiredDataLen,
1352 : m_kMaxMessageLen);
1353 0 : return false;
1354 : }
1355 :
1356 : // build the socket message to send
1357 : auto msgBuf =
1358 4 : m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + requiredDataLen);
1359 :
1360 4 : auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
1361 4 : bzero(header, sizeof(msghdr));
1362 :
1363 4 : auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
1364 4 : header->msg_control = ctrl;
1365 4 : header->msg_controllen = kRequiredCtrlLen;
1366 :
1367 4 : auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
1368 4 : header->msg_iov = iov;
1369 4 : header->msg_iovlen = 1;
1370 :
1371 4 : auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
1372 4 : iov->iov_base = data;
1373 4 : iov->iov_len = requiredDataLen;
1374 :
1375 : // copy in the data
1376 4 : message.SerializeWithCachedSizesToArray(data);
1377 :
1378 : // add the fds
1379 4 : if (!kFds.empty())
1380 : {
1381 0 : struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
1382 0 : if (!cmsg)
1383 : {
1384 0 : RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
1385 0 : return false;
1386 : }
1387 :
1388 0 : cmsg->cmsg_level = SOL_SOCKET;
1389 0 : cmsg->cmsg_type = SCM_RIGHTS;
1390 0 : cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
1391 0 : memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
1392 0 : header->msg_controllen = cmsg->cmsg_len;
1393 : }
1394 :
1395 : // finally, take the lock (so the socket is not closed beneath us) and send the reply
1396 4 : std::unique_lock<std::mutex> locker(m_clientsLock);
1397 :
1398 4 : auto it = m_clients.find(clientId);
1399 4 : if (it == m_clients.end() || it->second.sock < 0)
1400 : {
1401 0 : RIALTO_IPC_LOG_WARN("socket closed before event could be sent");
1402 0 : return false;
1403 : }
1404 4 : else if (TEMP_FAILURE_RETRY(sendmsg(it->second.sock, header, MSG_NOSIGNAL)) != static_cast<ssize_t>(requiredDataLen))
1405 : {
1406 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to send the complete event message");
1407 0 : return false;
1408 : }
1409 :
1410 4 : locker.unlock();
1411 :
1412 4 : RIALTO_IPC_LOG_DEBUG("event{ %s } - { %s }", eventMessage->GetTypeName().c_str(),
1413 : eventMessage->ShortDebugString().c_str());
1414 :
1415 4 : return true;
1416 : }
1417 : } // namespace firebolt::rialto::ipc
|