Line data Source code
1 : /*
2 : * If not stated otherwise in this file or this component's LICENSE file the
3 : * following copyright and licenses apply:
4 : *
5 : * Copyright 2022 Sky UK
6 : *
7 : * Licensed under the Apache License, Version 2.0 (the "License");
8 : * you may not use this file except in compliance with the License.
9 : * You may obtain a copy of the License at
10 : *
11 : * http://www.apache.org/licenses/LICENSE-2.0
12 : *
13 : * Unless required by applicable law or agreed to in writing, software
14 : * distributed under the License is distributed on an "AS IS" BASIS,
15 : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 : * See the License for the specific language governing permissions and
17 : * limitations under the License.
18 : */
19 :
20 : #include <algorithm>
21 : #include <cinttypes>
22 : #include <cstdarg>
23 : #include <memory>
24 : #include <utility>
25 :
26 : #include <fcntl.h>
27 : #include <poll.h>
28 : #include <sys/epoll.h>
29 : #include <sys/eventfd.h>
30 : #include <sys/socket.h>
31 : #include <sys/timerfd.h>
32 : #include <sys/un.h>
33 : #include <unistd.h>
34 :
35 : #include "IpcChannelImpl.h"
36 : #include "IpcLogging.h"
37 : #include "rialtoipc.pb.h"
38 :
39 : #if !defined(SCM_MAX_FD)
40 : #define SCM_MAX_FD 255
41 : #endif
42 :
43 : namespace
44 : {
45 : constexpr size_t kMaxMessageSize{128 * 1024};
46 : const std::chrono::milliseconds kDefaultIpcTimeout{3000};
47 :
48 33 : std::chrono::milliseconds getIpcTimeout()
49 : {
50 33 : const char *kCustomTimeout = getenv("RIALTO_CLIENT_IPC_TIMEOUT");
51 33 : std::chrono::milliseconds timeout{kDefaultIpcTimeout};
52 33 : if (kCustomTimeout)
53 : {
54 : try
55 : {
56 0 : timeout = std::chrono::milliseconds{std::stoull(kCustomTimeout)};
57 0 : RIALTO_IPC_LOG_INFO("Using custom Ipc timeout: %sms", kCustomTimeout);
58 : }
59 0 : catch (const std::exception &e)
60 : {
61 0 : RIALTO_IPC_LOG_ERROR("Custom Ipc timeout invalid, ignoring: %s", kCustomTimeout);
62 : }
63 : }
64 33 : return timeout;
65 : }
66 : } // namespace
67 :
68 : namespace firebolt::rialto::ipc
69 : {
70 :
71 33 : std::shared_ptr<IChannelFactory> IChannelFactory::createFactory()
72 : {
73 33 : std::shared_ptr<IChannelFactory> factory;
74 : try
75 : {
76 33 : factory = std::make_shared<ChannelFactory>();
77 : }
78 0 : catch (const std::exception &e)
79 : {
80 0 : RIALTO_IPC_LOG_ERROR("Failed to create the ipc channel factory, reason: %s", e.what());
81 : }
82 :
83 33 : return factory;
84 : }
85 :
86 17 : std::shared_ptr<IChannel> ChannelFactory::createChannel(int sockFd)
87 : {
88 17 : std::shared_ptr<IChannel> channel;
89 : try
90 : {
91 17 : channel = std::make_shared<ChannelImpl>(sockFd);
92 : }
93 1 : catch (const std::exception &e)
94 : {
95 1 : RIALTO_IPC_LOG_ERROR("Failed to create the ipc channel with socketFd %d, reason: %s", sockFd, e.what());
96 : }
97 :
98 17 : return channel;
99 : }
100 :
101 16 : std::shared_ptr<IChannel> ChannelFactory::createChannel(const std::string &socketPath)
102 : {
103 16 : std::shared_ptr<IChannel> channel;
104 : try
105 : {
106 16 : channel = std::make_shared<ChannelImpl>(socketPath);
107 : }
108 3 : catch (const std::exception &e)
109 : {
110 3 : RIALTO_IPC_LOG_ERROR("Failed to create the ipc channel with socketPath %s, reason: %s", socketPath.c_str(),
111 : e.what());
112 : }
113 :
114 16 : return channel;
115 : }
116 :
117 17 : ChannelImpl::ChannelImpl(int sock)
118 17 : : m_sock(-1), m_epollFd(-1), m_timerFd(-1), m_eventFd(-1), m_serialCounter(1), m_timeout(getIpcTimeout()),
119 34 : m_eventTagCounter(1)
120 : {
121 17 : if (!attachSocket(sock))
122 : {
123 1 : throw std::runtime_error("Failed attach the socket");
124 : }
125 16 : if (!initChannel())
126 : {
127 0 : termChannel();
128 0 : throw std::runtime_error("Failed to initalise the channel");
129 : }
130 16 : if (!isConnectedInternal())
131 : {
132 0 : termChannel();
133 0 : throw std::runtime_error("Channel not connected");
134 : }
135 20 : }
136 :
137 16 : ChannelImpl::ChannelImpl(const std::string &socketPath)
138 16 : : m_sock(-1), m_epollFd(-1), m_timerFd(-1), m_eventFd(-1), m_serialCounter(1), m_timeout(getIpcTimeout()),
139 32 : m_eventTagCounter(1)
140 : {
141 16 : if (!createConnectedSocket(socketPath))
142 : {
143 3 : throw std::runtime_error("Failed connect socket");
144 : }
145 13 : if (!initChannel())
146 : {
147 0 : termChannel();
148 0 : throw std::runtime_error("Failed to initalise the channel");
149 : }
150 13 : if (!isConnectedInternal())
151 : {
152 0 : termChannel();
153 0 : throw std::runtime_error("Channel not connected");
154 : }
155 25 : }
156 :
157 29 : ChannelImpl::~ChannelImpl()
158 : {
159 29 : termChannel();
160 : }
161 :
162 16 : bool ChannelImpl::createConnectedSocket(const std::string &socketPath)
163 : {
164 16 : int sock = socket(AF_UNIX, SOCK_SEQPACKET | SOCK_CLOEXEC | SOCK_NONBLOCK, 0);
165 16 : if (sock < 0)
166 : {
167 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to create socket");
168 0 : return false;
169 : }
170 :
171 16 : struct sockaddr_un addr = {0};
172 16 : memset(&addr, 0x00, sizeof(addr));
173 16 : addr.sun_family = AF_UNIX;
174 16 : strncpy(addr.sun_path, socketPath.c_str(), sizeof(addr.sun_path) - 1);
175 :
176 16 : if (::connect(sock, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr)) == -1)
177 : {
178 3 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to connect to %s", socketPath.c_str());
179 3 : close(sock);
180 3 : return false;
181 : }
182 :
183 13 : m_sock = sock;
184 :
185 13 : return true;
186 : }
187 :
188 17 : bool ChannelImpl::attachSocket(int sockFd)
189 : {
190 : // sanity check the supplied socket is of the right type
191 : struct sockaddr addr;
192 17 : socklen_t len = sizeof(addr);
193 17 : if ((getsockname(sockFd, &addr, &len) < 0) || (len < sizeof(sa_family_t)))
194 : {
195 1 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get name of supplied socket");
196 1 : return false;
197 : }
198 16 : if (addr.sa_family != AF_UNIX)
199 : {
200 0 : RIALTO_IPC_LOG_ERROR("supplied client socket is not a unix domain socket");
201 0 : return false;
202 : }
203 :
204 16 : int type = 0;
205 16 : len = sizeof(type);
206 16 : if ((getsockopt(sockFd, SOL_SOCKET, SO_TYPE, &type, &len) < 0) || (len != sizeof(type)))
207 : {
208 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
209 0 : return false;
210 : }
211 16 : if (type != SOCK_SEQPACKET)
212 : {
213 0 : RIALTO_IPC_LOG_ERROR("supplied client socket is not of type SOCK_SEQPACKET");
214 0 : return false;
215 : }
216 :
217 : // set the O_NONBLOCKING flag on the socket
218 16 : int flags = fcntl(sockFd, F_GETFL);
219 16 : if ((flags < 0) || (fcntl(sockFd, F_SETFL, flags | O_NONBLOCK) < 0))
220 : {
221 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to set socket to non-blocking mode");
222 0 : return false;
223 : }
224 :
225 16 : m_sock = sockFd;
226 :
227 16 : return true;
228 : }
229 :
230 29 : bool ChannelImpl::initChannel()
231 : {
232 : // create epoll so can listen for timeouts as well as socket messages
233 29 : m_epollFd = epoll_create1(EPOLL_CLOEXEC);
234 29 : if (m_epollFd < 0)
235 : {
236 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll failed");
237 0 : return false;
238 : }
239 :
240 : // add the socket to epoll
241 29 : epoll_event sockEvent = {.events = EPOLLIN, .data = {.fd = m_sock}};
242 29 : if (epoll_ctl(m_epollFd, EPOLL_CTL_ADD, m_sock, &sockEvent) != 0)
243 : {
244 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
245 0 : return false;
246 : }
247 :
248 29 : m_timerFd = timerfd_create(CLOCK_MONOTONIC, TFD_CLOEXEC | TFD_NONBLOCK);
249 29 : if (m_timerFd < 0)
250 : {
251 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "timerfd_create failed");
252 0 : return false;
253 : }
254 :
255 : // add the timer event to epoll
256 29 : epoll_event timerEvent = {.events = EPOLLIN, .data = {.fd = m_timerFd}};
257 29 : if (epoll_ctl(m_epollFd, EPOLL_CTL_ADD, m_timerFd, &timerEvent) != 0)
258 : {
259 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
260 0 : return false;
261 : }
262 :
263 : // and lastly the eventfd to wake the poll loop
264 29 : m_eventFd = eventfd(0, EFD_CLOEXEC);
265 29 : if (m_eventFd < 0)
266 : {
267 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "eventfd create failed");
268 0 : return false;
269 : }
270 :
271 : // add the timer event to epoll
272 29 : epoll_event wakeEvent = {.events = EPOLLIN, .data = {.fd = m_eventFd}};
273 29 : if (epoll_ctl(m_epollFd, EPOLL_CTL_ADD, m_eventFd, &wakeEvent) != 0)
274 : {
275 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
276 0 : return false;
277 : }
278 :
279 29 : return true;
280 : }
281 :
282 29 : void ChannelImpl::termChannel()
283 : {
284 : // close the socket and the epoll and timer fds
285 29 : if (m_sock >= 0)
286 0 : disconnectNoLock();
287 29 : if ((m_epollFd >= 0) && (close(m_epollFd) != 0))
288 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "closing epoll fd failed");
289 29 : if ((m_timerFd >= 0) && (close(m_timerFd) != 0))
290 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "closing timer fd failed");
291 29 : if ((m_eventFd >= 0) && (close(m_eventFd) != 0))
292 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "closing event fd failed");
293 :
294 : // if any method calls are still outstanding then complete them with errors now
295 29 : for (auto &entry : m_methodCalls)
296 : {
297 0 : completeWithError(&entry.second, "Channel destructed");
298 : }
299 :
300 29 : m_methodCalls.clear();
301 : }
302 :
303 29 : void ChannelImpl::disconnect()
304 : {
305 : // disconnect from the socket
306 : {
307 29 : std::lock_guard<std::mutex> locker(m_lock);
308 :
309 29 : if (m_sock < 0)
310 1 : return;
311 :
312 28 : disconnectNoLock();
313 16 : }
314 :
315 : // wake the wait(...) call so client code is blocked there it is woken
316 28 : if (m_eventFd >= 0)
317 : {
318 28 : uint64_t wakeup = 1;
319 28 : if (TEMP_FAILURE_RETRY(write(m_eventFd, &wakeup, sizeof(wakeup))) != sizeof(wakeup))
320 : {
321 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to write to wake event fd");
322 : }
323 : }
324 : }
325 :
326 : // -----------------------------------------------------------------------------
327 : /*!
328 : \internal
329 :
330 :
331 : */
332 29 : void ChannelImpl::disconnectNoLock()
333 : {
334 29 : if (m_sock < 0)
335 : {
336 0 : RIALTO_IPC_LOG_WARN("not connected\n");
337 0 : return;
338 : }
339 :
340 : // remove the socket from epoll
341 29 : if (epoll_ctl(m_epollFd, EPOLL_CTL_DEL, m_sock, nullptr) != 0)
342 : {
343 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to remove socket");
344 : }
345 :
346 : // shutdown and close the socket
347 29 : if (shutdown(m_sock, SHUT_RDWR) != 0)
348 : {
349 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "shutdown error");
350 : }
351 29 : if (close(m_sock) != 0)
352 : {
353 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "close error");
354 : }
355 :
356 29 : m_sock = -1;
357 : }
358 :
359 12695 : bool ChannelImpl::isConnected() const
360 : {
361 12695 : return isConnectedInternal();
362 : }
363 :
364 12724 : bool ChannelImpl::isConnectedInternal() const
365 : {
366 12724 : std::lock_guard<std::mutex> locker(m_lock);
367 12724 : return (m_sock >= 0);
368 : }
369 :
370 0 : int ChannelImpl::fd() const
371 : {
372 0 : return m_epollFd;
373 : }
374 :
375 3157 : bool ChannelImpl::wait(int timeoutMSecs)
376 : {
377 3157 : if ((m_epollFd < 0) || !isConnected())
378 : {
379 0 : return false;
380 : }
381 :
382 : // wait for any event (with timeout)
383 : struct pollfd fds[2];
384 3157 : fds[0].fd = m_epollFd;
385 3157 : fds[0].events = POLLIN;
386 :
387 3157 : int rc = TEMP_FAILURE_RETRY(poll(fds, 1, timeoutMSecs));
388 3157 : if (rc < 0)
389 : {
390 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "poll failed?");
391 0 : return false;
392 : }
393 :
394 3157 : return isConnected();
395 : }
396 :
397 3187 : bool ChannelImpl::process()
398 : {
399 3187 : if (!isConnected())
400 17 : return false;
401 :
402 : struct epoll_event events[3];
403 3170 : int rc = TEMP_FAILURE_RETRY(epoll_wait(m_epollFd, events, 3, 0));
404 3170 : if (rc < 0)
405 : {
406 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_wait failed");
407 0 : return false;
408 : }
409 :
410 : enum
411 : {
412 : HaveSocketEvent = 0x1,
413 : HaveTimeoutEvent = 0x2,
414 : HaveWakeEvent = 0x4
415 : };
416 3170 : unsigned eventsMask = 0;
417 3196 : for (int i = 0; i < rc; i++)
418 : {
419 26 : if (events[i].data.fd == m_sock)
420 24 : eventsMask |= HaveSocketEvent;
421 2 : else if (events[i].data.fd == m_timerFd)
422 2 : eventsMask |= HaveTimeoutEvent;
423 0 : else if (events[i].data.fd == m_eventFd)
424 0 : eventsMask |= HaveWakeEvent;
425 : }
426 :
427 3170 : if ((eventsMask & HaveSocketEvent) && !processSocketEvent())
428 : {
429 1 : std::map<uint64_t, MethodCall> callsToDelete;
430 : {
431 1 : std::lock_guard<std::mutex> locker(m_lock);
432 1 : callsToDelete = m_methodCalls;
433 1 : m_methodCalls.clear();
434 : }
435 :
436 1 : for (auto &entry : callsToDelete)
437 : {
438 0 : completeWithError(&entry.second, "Socket down");
439 : }
440 1 : return false;
441 : }
442 :
443 3169 : if (eventsMask & HaveTimeoutEvent)
444 2 : processTimeoutEvent();
445 :
446 3169 : if (eventsMask & HaveWakeEvent)
447 0 : processWakeEvent();
448 :
449 3169 : return isConnected();
450 : }
451 :
452 56 : bool ChannelImpl::unsubscribe(int eventTag)
453 : {
454 56 : std::lock_guard<std::mutex> locker(m_eventsLock);
455 56 : bool success = false;
456 :
457 56 : auto it = std::find_if(m_eventHandlers.begin(), m_eventHandlers.end(),
458 84 : [&](const auto &item) { return item.second.id == eventTag; });
459 56 : if (m_eventHandlers.end() != it)
460 : {
461 56 : m_eventHandlers.erase(it);
462 56 : success = true;
463 : }
464 :
465 56 : return success;
466 : }
467 :
468 : // -----------------------------------------------------------------------------
469 : /*!
470 : \internal
471 :
472 :
473 : */
474 24 : bool ChannelImpl::processSocketEvent()
475 : {
476 : static std::mutex bufLock;
477 24 : std::lock_guard<std::mutex> bufLocker(bufLock);
478 :
479 28 : static std::vector<uint8_t> dataBuf(kMaxMessageSize);
480 28 : static std::vector<uint8_t> ctrlBuf(CMSG_SPACE(SCM_MAX_FD * sizeof(int)));
481 :
482 : // read all messages from the client socket, we break out if the socket is closed
483 : // or EWOULDBLOCK is returned on a read (ie. no more messages to read)
484 : while (true)
485 : {
486 47 : struct msghdr msg = {nullptr};
487 47 : struct iovec io = {.iov_base = dataBuf.data(), .iov_len = dataBuf.size()};
488 :
489 47 : bzero(&msg, sizeof(msg));
490 47 : msg.msg_iov = &io;
491 47 : msg.msg_iovlen = 1;
492 47 : msg.msg_control = ctrlBuf.data();
493 47 : msg.msg_controllen = ctrlBuf.size();
494 :
495 : // read one message
496 47 : ssize_t rd = TEMP_FAILURE_RETRY(recvmsg(m_sock, &msg, MSG_CMSG_CLOEXEC));
497 47 : if (rd < 0)
498 : {
499 23 : if (errno != EWOULDBLOCK)
500 : {
501 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "error reading client socket");
502 :
503 0 : std::lock_guard<std::mutex> locker(m_lock);
504 0 : disconnectNoLock();
505 0 : return false;
506 : }
507 :
508 23 : break;
509 : }
510 24 : else if (rd == 0)
511 : {
512 : // server closed connection, and we've read all data
513 1 : RIALTO_IPC_LOG_INFO("socket remote end closed, disconnecting channel");
514 :
515 1 : std::lock_guard<std::mutex> locker(m_lock);
516 1 : disconnectNoLock();
517 1 : return false;
518 : }
519 23 : else if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))
520 : {
521 0 : RIALTO_IPC_LOG_WARN("received truncated message from server, discarding");
522 :
523 : // make sure to close all the fds, otherwise we'll leak them, this
524 : // will read the fds and return in a vector, which will then be
525 : // destroyed, closing all the fds
526 0 : readMessageFds(&msg, 16);
527 : }
528 : else
529 : {
530 : // if there is control data then assume fd(s) have been passed
531 23 : std::vector<FileDescriptor> fds;
532 23 : if (msg.msg_controllen > 0)
533 : {
534 0 : fds = readMessageFds(&msg, 32);
535 : }
536 :
537 : // process the message from the server
538 23 : processServerMessage(dataBuf.data(), rd, &fds);
539 : }
540 : }
541 :
542 23 : return true;
543 13 : }
544 :
545 : // -----------------------------------------------------------------------------
546 : /*!
547 : \internal
548 :
549 : Called from process() to check if the timerfd has expired and if so cancel
550 : the any outstanding method calls that have now timed-out.
551 :
552 : */
553 2 : void ChannelImpl::processTimeoutEvent()
554 : {
555 : // read the timerfd to clear any expirations
556 : uint64_t expirations;
557 2 : ssize_t rd = TEMP_FAILURE_RETRY(read(m_timerFd, &expirations, sizeof(expirations)));
558 2 : if (rd < 0)
559 : {
560 1 : if (errno != EWOULDBLOCK)
561 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "error reading timerfd");
562 1 : return;
563 : }
564 :
565 : // stores the timed-out method calls
566 1 : std::vector<MethodCall> timedOuts;
567 :
568 : {
569 : // check if any method call has now expired
570 1 : std::unique_lock<std::mutex> locker(m_lock);
571 :
572 : // remove the method calls that have expired
573 1 : const auto kNow = std::chrono::steady_clock::now();
574 1 : auto it = m_methodCalls.begin();
575 2 : while (it != m_methodCalls.end())
576 : {
577 1 : if (kNow >= it->second.timeoutDeadline)
578 : {
579 1 : timedOuts.emplace_back(it->second);
580 1 : it = m_methodCalls.erase(it);
581 : }
582 : else
583 : {
584 0 : ++it;
585 : }
586 : }
587 :
588 : // if we still have method calls available, then re-calculate the timer for the next timeout
589 1 : if (!m_methodCalls.empty())
590 : {
591 0 : updateTimeoutTimer();
592 : }
593 1 : }
594 :
595 2 : for (auto &call : timedOuts)
596 : {
597 2 : completeWithError(&call, "Timed out");
598 : }
599 1 : }
600 :
601 : // -----------------------------------------------------------------------------
602 : /*!
603 : \internal
604 :
605 : Called from process() when the eventfd was used to wake the event loop. All
606 : the function does is read the eventfd to clear it's value.
607 :
608 : */
609 0 : void ChannelImpl::processWakeEvent()
610 : {
611 : uint64_t ignore;
612 0 : if (TEMP_FAILURE_RETRY(read(m_eventFd, &ignore, sizeof(ignore))) != sizeof(ignore))
613 : {
614 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "Failed to read wake eventfd to clear it");
615 : }
616 : }
617 :
618 : // -----------------------------------------------------------------------------
619 : /*!
620 : \internal
621 :
622 : Updates the timerfd to the time of the next method call timeout. If method
623 : calls are pending then the timer is disabled.
624 :
625 : This should be called whenever a new method is called or a method has
626 : completed.
627 :
628 : \note Must be called while holding the m_lock mutex.
629 :
630 : */
631 39 : void ChannelImpl::updateTimeoutTimer()
632 : {
633 39 : struct itimerspec ts = {{0}};
634 :
635 : // if no method calls then just disarm the timer
636 39 : if (!m_methodCalls.empty())
637 : {
638 : // otherwise, find the next soonest timeout
639 20 : std::chrono::steady_clock::time_point nextTimeout = std::chrono::steady_clock::time_point::max();
640 : auto nextTimeoutCall =
641 20 : std::min_element(m_methodCalls.begin(), m_methodCalls.end(), [](const auto &elem, const auto ¤tMin)
642 0 : { return elem.second.timeoutDeadline < currentMin.second.timeoutDeadline; });
643 20 : if (nextTimeoutCall != m_methodCalls.end())
644 : {
645 20 : nextTimeout = nextTimeoutCall->second.timeoutDeadline;
646 : }
647 :
648 : // set the timerfd to the next duration
649 : const std::chrono::microseconds kDuration =
650 20 : std::chrono::duration_cast<std::chrono::microseconds>(nextTimeout - std::chrono::steady_clock::now());
651 20 : if (kDuration <= std::chrono::microseconds::zero())
652 : {
653 0 : ts.it_value.tv_nsec = 1000;
654 : }
655 : else
656 : {
657 20 : ts.it_value.tv_sec = static_cast<time_t>(std::chrono::duration_cast<std::chrono::seconds>(kDuration).count());
658 20 : ts.it_value.tv_nsec = static_cast<int32_t>((kDuration.count() % 1000000) * 1000);
659 : }
660 :
661 20 : RIALTO_IPC_LOG_DEBUG("next timeout in %" PRId64 "us - %ld.%09lds", kDuration.count(), ts.it_value.tv_sec,
662 : ts.it_value.tv_nsec);
663 : }
664 :
665 : // write the timeout value
666 39 : if (timerfd_settime(m_timerFd, 0, &ts, nullptr) != 0)
667 : {
668 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to write to timerfd");
669 : }
670 39 : }
671 :
672 : // -----------------------------------------------------------------------------
673 : /*!
674 : \internal
675 :
676 : Processes a single message from the server, it may be a method call response
677 : or an event.
678 :
679 :
680 : */
681 23 : void ChannelImpl::processServerMessage(const uint8_t *data, size_t dataLen, std::vector<FileDescriptor> *fds)
682 : {
683 : // parse the message
684 23 : transport::MessageFromServer message;
685 23 : if (!message.ParseFromArray(data, static_cast<int>(dataLen)))
686 : {
687 0 : RIALTO_IPC_LOG_ERROR("invalid message from server");
688 0 : return;
689 : }
690 :
691 : // check if an event or a reply to a request
692 23 : if (message.has_reply())
693 : {
694 13 : processReplyFromServer(message.reply(), fds);
695 : }
696 10 : else if (message.has_error())
697 : {
698 6 : processErrorFromServer(message.error());
699 : }
700 4 : else if (message.has_event())
701 : {
702 4 : processEventFromServer(message.event(), fds);
703 : }
704 : else
705 : {
706 0 : RIALTO_IPC_LOG_ERROR("message from server is missing reply or event type");
707 : }
708 23 : }
709 :
710 : // -----------------------------------------------------------------------------
711 : /*!
712 : \internal
713 :
714 :
715 : */
716 13 : void ChannelImpl::processReplyFromServer(const transport::MethodCallReply &reply, std::vector<FileDescriptor> *fds)
717 : {
718 13 : RIALTO_IPC_LOG_DEBUG("processing reply from server");
719 :
720 13 : MethodCall methodCall;
721 13 : const uint64_t kSerialId = reply.reply_id();
722 :
723 : {
724 13 : std::lock_guard<std::mutex> locker(m_lock);
725 :
726 13 : auto it = m_methodCalls.find(kSerialId);
727 13 : if (it == m_methodCalls.end())
728 : {
729 0 : RIALTO_IPC_LOG_ERROR("failed to find request for received reply with id %" PRIu64 "", reply.reply_id());
730 0 : return;
731 : }
732 :
733 13 : methodCall = it->second;
734 13 : m_methodCalls.erase(it);
735 :
736 13 : updateTimeoutTimer();
737 : }
738 :
739 13 : if (!methodCall.response->ParseFromString(reply.reply_message()))
740 : {
741 0 : RIALTO_IPC_LOG_ERROR("failed to parse method reply from server");
742 0 : completeWithError(&methodCall, "Failed to parse reply message");
743 : }
744 13 : else if (!addReplyFileDescriptors(methodCall.response, fds))
745 : {
746 0 : RIALTO_IPC_LOG_ERROR("mismatch of file descriptors to the reply");
747 0 : completeWithError(&methodCall, "Mismatched file descriptors in message");
748 : }
749 13 : else if (methodCall.closure)
750 : {
751 13 : RIALTO_IPC_LOG_DEBUG("reply{ serial %" PRIu64 " } - %s { %s }", kSerialId,
752 : methodCall.response->GetTypeName().c_str(), methodCall.response->ShortDebugString().c_str());
753 :
754 13 : complete(&methodCall);
755 : }
756 : }
757 :
758 : // -----------------------------------------------------------------------------
759 : /*!
760 : \internal
761 :
762 :
763 : */
764 6 : void ChannelImpl::processErrorFromServer(const transport::MethodCallError &error)
765 : {
766 6 : RIALTO_IPC_LOG_DEBUG("processing error from server");
767 :
768 6 : std::unique_lock<std::mutex> locker(m_lock);
769 :
770 : // find the original request
771 6 : const uint64_t kSerialId = error.reply_id();
772 6 : auto it = m_methodCalls.find(kSerialId);
773 6 : if (it == m_methodCalls.end())
774 : {
775 0 : RIALTO_IPC_LOG_ERROR("failed to find request for received reply with id %" PRIu64 "", error.reply_id());
776 0 : return;
777 : }
778 :
779 : // take the method call and remove from the map of outstanding calls
780 6 : MethodCall methodCall = it->second;
781 6 : m_methodCalls.erase(it);
782 :
783 : // update the timeout timer now a method call has been processed
784 6 : updateTimeoutTimer();
785 :
786 : // can now drop the lock
787 6 : locker.unlock();
788 :
789 6 : RIALTO_IPC_LOG_DEBUG("error{ serial %" PRIu64 " } - %s", kSerialId, error.error_reason().c_str());
790 :
791 : // complete the call with an error
792 6 : completeWithError(&methodCall, error.error_reason());
793 : }
794 :
795 : // -----------------------------------------------------------------------------
796 : /*!
797 : \internal
798 :
799 :
800 : */
801 4 : void ChannelImpl::processEventFromServer(const transport::EventFromServer &event, std::vector<FileDescriptor> *fds)
802 : {
803 4 : RIALTO_IPC_LOG_DEBUG("processing event from server");
804 :
805 4 : const std::string &kEventName = event.event_name();
806 :
807 4 : std::lock_guard<std::mutex> locker(m_eventsLock);
808 :
809 4 : auto range = m_eventHandlers.equal_range(kEventName);
810 4 : if (range.first == range.second)
811 : {
812 0 : RIALTO_IPC_LOG_WARN("no handler for event %s", kEventName.c_str());
813 0 : return;
814 : }
815 :
816 4 : const google::protobuf::Descriptor *kDescriptor = range.first->second.descriptor;
817 :
818 : const google::protobuf::Message *kPrototype =
819 4 : google::protobuf::MessageFactory::generated_factory()->GetPrototype(kDescriptor);
820 4 : if (!kPrototype)
821 : {
822 0 : RIALTO_IPC_LOG_ERROR("failed to create prototype for event %s", kEventName.c_str());
823 0 : return;
824 : }
825 :
826 4 : std::shared_ptr<google::protobuf::Message> message(kPrototype->New());
827 4 : if (!message)
828 : {
829 0 : RIALTO_IPC_LOG_ERROR("failed to create mutable message from prototype");
830 0 : return;
831 : }
832 :
833 4 : if (!message->ParseFromString(event.message()))
834 : {
835 0 : RIALTO_IPC_LOG_ERROR("failed to parse message for event %s", kEventName.c_str());
836 : }
837 4 : else if (!addReplyFileDescriptors(message.get(), fds))
838 : {
839 0 : RIALTO_IPC_LOG_ERROR("mismatch of file descriptors to the reply");
840 : }
841 : else
842 : {
843 4 : RIALTO_IPC_LOG_DEBUG("event{ %s } - %s { %s }", kEventName.c_str(), message->GetTypeName().c_str(),
844 : message->ShortDebugString().c_str());
845 :
846 8 : for (auto it = range.first; it != range.second; ++it)
847 : {
848 4 : it->second.handler(message);
849 : }
850 : }
851 : }
852 :
853 : // -----------------------------------------------------------------------------
854 : /*!
855 : \internal
856 : \static
857 :
858 : Reads all the file descriptors from a unix domain socket received \a msg.
859 : It returns the file descriptors as a vector of FileDescriptor objects,
860 : these objects safely store the fd and close them when they're destructed.
861 :
862 : The \a limit specifies the maximum number of fds to store, if more were
863 : sent then they are automatically closed and not returned in the vector.
864 :
865 : */
866 0 : std::vector<FileDescriptor> ChannelImpl::readMessageFds(const struct msghdr *msg, size_t limit)
867 : {
868 0 : std::vector<FileDescriptor> fds;
869 :
870 0 : for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg); cmsg != nullptr;
871 0 : cmsg = CMSG_NXTHDR(const_cast<struct msghdr *>(msg), cmsg))
872 : {
873 0 : if ((cmsg->cmsg_level == SOL_SOCKET) && (cmsg->cmsg_type == SCM_RIGHTS))
874 : {
875 0 : const unsigned kFdsLength = cmsg->cmsg_len - CMSG_LEN(0);
876 0 : if ((kFdsLength < sizeof(int)) || ((kFdsLength % sizeof(int)) != 0))
877 : {
878 0 : RIALTO_IPC_LOG_ERROR("invalid fd array size");
879 : }
880 : else
881 : {
882 0 : const size_t n = kFdsLength / sizeof(int);
883 0 : RIALTO_IPC_LOG_DEBUG("received %zu fds", n);
884 :
885 0 : fds.reserve(std::min(limit, n));
886 :
887 0 : const int *kFds = reinterpret_cast<int *>(CMSG_DATA(cmsg));
888 0 : for (size_t i = 0; i < n; i++)
889 : {
890 0 : RIALTO_IPC_LOG_DEBUG("received fd %d", kFds[i]);
891 :
892 0 : if (fds.size() >= limit)
893 : {
894 0 : RIALTO_IPC_LOG_ERROR(
895 : "received to many file descriptors, exceeding max per message, closing left overs");
896 : }
897 : else
898 : {
899 0 : firebolt::rialto::ipc::FileDescriptor fileDescriptor(kFds[i]);
900 0 : if (!fileDescriptor.isValid())
901 : {
902 0 : RIALTO_IPC_LOG_ERROR("received invalid fd (couldn't dup)");
903 : }
904 : else
905 : {
906 0 : fds.emplace_back(std::move(fileDescriptor));
907 : }
908 : }
909 :
910 0 : if (close(kFds[i]) != 0)
911 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close received fd");
912 : }
913 : }
914 : }
915 : }
916 :
917 0 : return fds;
918 : }
919 :
920 : // -----------------------------------------------------------------------------
921 : /*!
922 : \static
923 : \internal
924 :
925 : Places the received file descriptors into the protobuf message.
926 :
927 : It works by iterating over the fields in the message, finding ones that are
928 : marked as 'field_is_fd' and then replacing the received integer value with
929 : an actual file descriptor.
930 :
931 : */
932 17 : bool ChannelImpl::addReplyFileDescriptors(google::protobuf::Message *reply,
933 : std::vector<firebolt::rialto::ipc::FileDescriptor> *fds)
934 : {
935 17 : auto fdIterator = fds->begin();
936 :
937 17 : const google::protobuf::Descriptor *kDescriptor = reply->GetDescriptor();
938 17 : const google::protobuf::Reflection *kReflection = nullptr;
939 :
940 17 : const int n = kDescriptor->field_count();
941 34 : for (int i = 0; i < n; i++)
942 : {
943 17 : auto fieldDescriptor = kDescriptor->field(i);
944 17 : if (fieldDescriptor->options().HasExtension(::firebolt::rialto::ipc::field_is_fd) &&
945 0 : fieldDescriptor->options().GetExtension(::firebolt::rialto::ipc::field_is_fd))
946 : {
947 0 : if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
948 : {
949 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
950 0 : return false;
951 : }
952 :
953 0 : if (!kReflection)
954 : {
955 0 : kReflection = reply->GetReflection();
956 : }
957 :
958 0 : if (kReflection->HasField(*reply, fieldDescriptor))
959 : {
960 0 : if (fdIterator == fds->end())
961 : {
962 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but none or too few were supplied");
963 0 : return false;
964 : }
965 :
966 0 : kReflection->SetInt32(reply, fieldDescriptor, fdIterator->fd());
967 0 : ++fdIterator;
968 : }
969 : }
970 : }
971 :
972 17 : if (fdIterator != fds->end())
973 : {
974 0 : RIALTO_IPC_LOG_ERROR("received too many file descriptors in the message");
975 0 : return false;
976 : }
977 :
978 : // we now need to release all the fds stored in the vector, otherwise they
979 : // will be closed when the vector is destroyed. From now onwards it is the
980 : // caller's responsibility to close the fds in the returned protobuf message
981 : // object
982 17 : for (firebolt::rialto::ipc::FileDescriptor &fd : *fds)
983 : {
984 0 : fd.release();
985 : }
986 :
987 17 : return true;
988 : }
989 :
990 : // -----------------------------------------------------------------------------
991 : /*!
992 : \internal
993 : \static
994 :
995 :
996 : */
997 13 : void ChannelImpl::complete(MethodCall *call)
998 : {
999 13 : if (call->closure)
1000 : {
1001 13 : call->closure->Run();
1002 13 : call->closure = nullptr;
1003 : }
1004 : }
1005 :
1006 : // -----------------------------------------------------------------------------
1007 : /*!
1008 : \internal
1009 : \static
1010 :
1011 :
1012 : */
1013 7 : void ChannelImpl::completeWithError(MethodCall *call, std::string reason)
1014 : {
1015 7 : RIALTO_IPC_LOG_DEBUG("completing method call with error '%s'", reason.c_str());
1016 :
1017 7 : if (call->controller)
1018 : {
1019 7 : call->controller->setMethodCallFailed(std::move(reason));
1020 7 : call->controller = nullptr;
1021 : }
1022 :
1023 7 : if (call->closure)
1024 : {
1025 7 : call->closure->Run();
1026 7 : call->closure = nullptr;
1027 : }
1028 : }
1029 :
1030 : // -----------------------------------------------------------------------------
1031 : /*!
1032 : \static
1033 :
1034 : Iterates through the message and finds any file descriptor type fields, if
1035 : found it adds the fds to the returned vector.
1036 :
1037 : */
1038 21 : std::vector<int> ChannelImpl::getMessageFds(const google::protobuf::Message &message)
1039 : {
1040 21 : std::vector<int> fds;
1041 :
1042 21 : auto descriptor = message.GetDescriptor();
1043 21 : const int n = descriptor->field_count();
1044 79 : for (int i = 0; i < n; i++)
1045 : {
1046 58 : auto fieldDescriptor = descriptor->field(i);
1047 58 : if (fieldDescriptor->options().HasExtension(::firebolt::rialto::ipc::field_is_fd) &&
1048 0 : fieldDescriptor->options().GetExtension(::firebolt::rialto::ipc::field_is_fd))
1049 : {
1050 0 : if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
1051 : {
1052 0 : RIALTO_IPC_LOG_ERROR("field '%s' is marked as containing an fd but not an int32 type",
1053 : fieldDescriptor->full_name().c_str());
1054 : }
1055 : else
1056 : {
1057 0 : auto reflection = message.GetReflection();
1058 0 : int fileDescriptor = reflection->GetInt32(message, fieldDescriptor);
1059 0 : fds.emplace_back(fileDescriptor);
1060 : }
1061 : }
1062 : }
1063 :
1064 21 : return fds;
1065 : }
1066 :
1067 : // -----------------------------------------------------------------------------
1068 : /*!
1069 : \overload
1070 :
1071 : \quote
1072 : Call the given method of the remote service. The signature of this
1073 : procedure looks the same as Service::CallMethod(), but the requirements
1074 : are less strict in one important way: the request and response objects
1075 : need not be of any specific class as long as their descriptors are
1076 : method->input_type() and method->output_type().
1077 :
1078 : */
1079 21 : void ChannelImpl::CallMethod(const google::protobuf::MethodDescriptor *method, // NOLINT(build/function_format)
1080 : google::protobuf::RpcController *controller, const google::protobuf::Message *request,
1081 : google::protobuf::Message *response, google::protobuf::Closure *done)
1082 : {
1083 21 : MethodCall methodCall{std::chrono::steady_clock::now() + m_timeout,
1084 21 : dynamic_cast<ClientControllerImpl *>(controller), response, done};
1085 :
1086 : //
1087 21 : const uint64_t kSerialId = m_serialCounter++;
1088 :
1089 : // create the transport request
1090 21 : transport::MessageToServer message;
1091 21 : transport::MethodCall *call = message.mutable_call();
1092 21 : call->set_serial_id(kSerialId);
1093 21 : call->set_service_name(method->service()->full_name());
1094 21 : call->set_method_name(method->name());
1095 :
1096 : // copy in the actual message data
1097 21 : std::string reqString = request->SerializeAsString();
1098 21 : call->set_request_message(std::move(reqString));
1099 :
1100 21 : const size_t kRequiredDataLen = message.ByteSizeLong();
1101 21 : if (kRequiredDataLen > kMaxMessageSize)
1102 : {
1103 0 : RIALTO_IPC_LOG_ERROR("method call to big to send (%zu, max %zu", kRequiredDataLen, kMaxMessageSize);
1104 0 : completeWithError(&methodCall, "Method call to big");
1105 0 : return;
1106 : }
1107 :
1108 : // extract the fds from the message
1109 21 : const std::vector<int> kFds = getMessageFds(*request);
1110 21 : const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
1111 :
1112 : // build the socket message to send
1113 : auto msgBuf =
1114 21 : m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + kRequiredDataLen);
1115 :
1116 21 : auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
1117 21 : bzero(header, sizeof(msghdr));
1118 :
1119 21 : auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
1120 21 : header->msg_control = ctrl;
1121 21 : header->msg_controllen = kRequiredCtrlLen;
1122 :
1123 21 : auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
1124 21 : header->msg_iov = iov;
1125 21 : header->msg_iovlen = 1;
1126 :
1127 21 : auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
1128 21 : iov->iov_base = data;
1129 21 : iov->iov_len = kRequiredDataLen;
1130 :
1131 : // copy in the data
1132 21 : message.SerializeWithCachedSizesToArray(data);
1133 :
1134 : // next check if the request is sending any fd's
1135 21 : if (!kFds.empty())
1136 : {
1137 0 : struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
1138 0 : if (!cmsg)
1139 : {
1140 0 : RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
1141 0 : completeWithError(&methodCall, "Internal error");
1142 0 : return;
1143 : }
1144 :
1145 0 : cmsg->cmsg_level = SOL_SOCKET;
1146 0 : cmsg->cmsg_type = SCM_RIGHTS;
1147 0 : cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
1148 0 : memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
1149 0 : header->msg_controllen = cmsg->cmsg_len;
1150 : }
1151 :
1152 : // check if the method is expecting a reply
1153 22 : const bool kNoReplyExpected = method->options().HasExtension(::firebolt::rialto::ipc::no_reply) &&
1154 1 : method->options().GetExtension(::firebolt::rialto::ipc::no_reply);
1155 :
1156 : // finally, send the message
1157 21 : std::unique_lock<std::mutex> locker(m_lock);
1158 :
1159 21 : if (m_sock < 0)
1160 : {
1161 0 : locker.unlock();
1162 0 : completeWithError(&methodCall, "Not connected");
1163 : }
1164 21 : else if (sendmsg(m_sock, header, MSG_NOSIGNAL) != static_cast<ssize_t>(kRequiredDataLen))
1165 : {
1166 0 : locker.unlock();
1167 0 : completeWithError(&methodCall, "Failed to send message");
1168 : }
1169 : else
1170 : {
1171 21 : RIALTO_IPC_LOG_DEBUG("call{ serial %" PRIu64 " } - %s.%s { %s }", kSerialId, call->service_name().c_str(),
1172 : call->method_name().c_str(), request->ShortDebugString().c_str());
1173 :
1174 21 : if (kNoReplyExpected)
1175 : {
1176 : // no reply from server is expected, however if the caller supplied
1177 : // a closure (it shouldn't) we should still call it now to indicate
1178 : // the method call has been made
1179 1 : if (done)
1180 1 : done->Run();
1181 : }
1182 : else
1183 : {
1184 : // add the message to the queue so we pick-up the reply
1185 20 : m_methodCalls.emplace(kSerialId, methodCall);
1186 :
1187 : // update the single timeout timer
1188 20 : updateTimeoutTimer();
1189 : }
1190 : }
1191 11 : }
1192 :
1193 58 : int ChannelImpl::subscribeImpl(const std::string &kEventName, const google::protobuf::Descriptor *descriptor,
1194 : EventHandler &&handler)
1195 : {
1196 58 : std::lock_guard<std::mutex> locker(m_eventsLock);
1197 :
1198 58 : const int kTag = m_eventTagCounter++;
1199 58 : m_eventHandlers.emplace(kEventName, Event{kTag, descriptor, std::move(handler)});
1200 :
1201 58 : return kTag;
1202 : }
1203 : }; // namespace firebolt::rialto::ipc
|