Line data Source code
1 : /*
2 : * If not stated otherwise in this file or this component's LICENSE file the
3 : * following copyright and licenses apply:
4 : *
5 : * Copyright 2022 Sky UK
6 : *
7 : * Licensed under the Apache License, Version 2.0 (the "License");
8 : * you may not use this file except in compliance with the License.
9 : * You may obtain a copy of the License at
10 : *
11 : * http://www.apache.org/licenses/LICENSE-2.0
12 : *
13 : * Unless required by applicable law or agreed to in writing, software
14 : * distributed under the License is distributed on an "AS IS" BASIS,
15 : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 : * See the License for the specific language governing permissions and
17 : * limitations under the License.
18 : */
19 :
20 : #include <algorithm>
21 : #include <cinttypes>
22 : #include <cstdarg>
23 : #include <memory>
24 : #include <utility>
25 :
26 : #include <fcntl.h>
27 : #include <poll.h>
28 : #include <sys/epoll.h>
29 : #include <sys/eventfd.h>
30 : #include <sys/socket.h>
31 : #include <sys/timerfd.h>
32 : #include <sys/un.h>
33 : #include <unistd.h>
34 :
35 : #include "IpcChannelImpl.h"
36 : #include "IpcLogging.h"
37 : #include "rialtoipc.pb.h"
38 :
39 : #if !defined(SCM_MAX_FD)
40 : #define SCM_MAX_FD 255
41 : #endif
42 :
43 : namespace
44 : {
45 : constexpr size_t kMaxMessageSize{128 * 1024};
46 : const std::chrono::milliseconds kDefaultIpcTimeout{3000};
47 :
48 27 : std::chrono::milliseconds getIpcTimeout()
49 : {
50 27 : const char *kCustomTimeout = getenv("RIALTO_CLIENT_IPC_TIMEOUT");
51 27 : std::chrono::milliseconds timeout{kDefaultIpcTimeout};
52 27 : if (kCustomTimeout)
53 : {
54 : try
55 : {
56 0 : timeout = std::chrono::milliseconds{std::stoull(kCustomTimeout)};
57 0 : RIALTO_IPC_LOG_INFO("Using custom Ipc timeout: %sms", kCustomTimeout);
58 : }
59 0 : catch (const std::exception &e)
60 : {
61 0 : RIALTO_IPC_LOG_ERROR("Custom Ipc timeout invalid, ignoring: %s", kCustomTimeout);
62 : }
63 : }
64 27 : return timeout;
65 : }
66 : } // namespace
67 :
68 : namespace firebolt::rialto::ipc
69 : {
70 :
71 27 : std::shared_ptr<IChannelFactory> IChannelFactory::createFactory()
72 : {
73 27 : std::shared_ptr<IChannelFactory> factory;
74 : try
75 : {
76 27 : factory = std::make_shared<ChannelFactory>();
77 : }
78 0 : catch (const std::exception &e)
79 : {
80 0 : RIALTO_IPC_LOG_ERROR("Failed to create the ipc channel factory, reason: %s", e.what());
81 : }
82 :
83 27 : return factory;
84 : }
85 :
86 15 : std::shared_ptr<IChannel> ChannelFactory::createChannel(int sockFd)
87 : {
88 15 : std::shared_ptr<IChannel> channel;
89 : try
90 : {
91 15 : channel = std::make_shared<ChannelImpl>(sockFd);
92 : }
93 1 : catch (const std::exception &e)
94 : {
95 1 : RIALTO_IPC_LOG_ERROR("Failed to create the ipc channel with socketFd %d, reason: %s", sockFd, e.what());
96 : }
97 :
98 15 : return channel;
99 : }
100 :
101 12 : std::shared_ptr<IChannel> ChannelFactory::createChannel(const std::string &socketPath)
102 : {
103 12 : std::shared_ptr<IChannel> channel;
104 : try
105 : {
106 12 : channel = std::make_shared<ChannelImpl>(socketPath);
107 : }
108 3 : catch (const std::exception &e)
109 : {
110 3 : RIALTO_IPC_LOG_ERROR("Failed to create the ipc channel with socketPath %s, reason: %s", socketPath.c_str(),
111 : e.what());
112 : }
113 :
114 12 : return channel;
115 : }
116 :
117 15 : ChannelImpl::ChannelImpl(int sock)
118 15 : : m_sock(-1), m_epollFd(-1), m_timerFd(-1), m_eventFd(-1), m_serialCounter(1), m_timeout(getIpcTimeout()),
119 30 : m_eventTagCounter(1)
120 : {
121 15 : if (!attachSocket(sock))
122 : {
123 1 : throw std::runtime_error("Failed attach the socket");
124 : }
125 14 : if (!initChannel())
126 : {
127 0 : termChannel();
128 0 : throw std::runtime_error("Failed to initalise the channel");
129 : }
130 14 : if (!isConnectedInternal())
131 : {
132 0 : termChannel();
133 0 : throw std::runtime_error("Channel not connected");
134 : }
135 18 : }
136 :
137 12 : ChannelImpl::ChannelImpl(const std::string &socketPath)
138 12 : : m_sock(-1), m_epollFd(-1), m_timerFd(-1), m_eventFd(-1), m_serialCounter(1), m_timeout(getIpcTimeout()),
139 24 : m_eventTagCounter(1)
140 : {
141 12 : if (!createConnectedSocket(socketPath))
142 : {
143 3 : throw std::runtime_error("Failed connect socket");
144 : }
145 9 : if (!initChannel())
146 : {
147 0 : termChannel();
148 0 : throw std::runtime_error("Failed to initalise the channel");
149 : }
150 9 : if (!isConnectedInternal())
151 : {
152 0 : termChannel();
153 0 : throw std::runtime_error("Channel not connected");
154 : }
155 21 : }
156 :
157 23 : ChannelImpl::~ChannelImpl()
158 : {
159 23 : termChannel();
160 : }
161 :
162 12 : bool ChannelImpl::createConnectedSocket(const std::string &socketPath)
163 : {
164 12 : int sock = socket(AF_UNIX, SOCK_SEQPACKET | SOCK_CLOEXEC | SOCK_NONBLOCK, 0);
165 12 : if (sock < 0)
166 : {
167 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to create socket");
168 0 : return false;
169 : }
170 :
171 12 : struct sockaddr_un addr = {0};
172 12 : memset(&addr, 0x00, sizeof(addr));
173 12 : addr.sun_family = AF_UNIX;
174 12 : strncpy(addr.sun_path, socketPath.c_str(), sizeof(addr.sun_path) - 1);
175 :
176 12 : if (::connect(sock, (struct sockaddr *)&addr, sizeof(addr)) == -1)
177 : {
178 3 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to connect to %s", socketPath.c_str());
179 3 : close(sock);
180 3 : return false;
181 : }
182 :
183 9 : m_sock = sock;
184 :
185 9 : return true;
186 : }
187 :
188 15 : bool ChannelImpl::attachSocket(int sockFd)
189 : {
190 : // sanity check the supplied socket is of the right type
191 : struct sockaddr addr;
192 15 : socklen_t len = sizeof(addr);
193 15 : if ((getsockname(sockFd, &addr, &len) < 0) || (len < sizeof(sa_family_t)))
194 : {
195 1 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get name of supplied socket");
196 1 : return false;
197 : }
198 14 : if (addr.sa_family != AF_UNIX)
199 : {
200 0 : RIALTO_IPC_LOG_ERROR("supplied client socket is not a unix domain socket");
201 0 : return false;
202 : }
203 :
204 14 : int type = 0;
205 14 : len = sizeof(type);
206 14 : if ((getsockopt(sockFd, SOL_SOCKET, SO_TYPE, &type, &len) < 0) || (len != sizeof(type)))
207 : {
208 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
209 0 : return false;
210 : }
211 14 : if (type != SOCK_SEQPACKET)
212 : {
213 0 : RIALTO_IPC_LOG_ERROR("supplied client socket is not of type SOCK_SEQPACKET");
214 0 : return false;
215 : }
216 :
217 : // set the O_NONBLOCKING flag on the socket
218 14 : int flags = fcntl(sockFd, F_GETFL);
219 14 : if ((flags < 0) || (fcntl(sockFd, F_SETFL, flags | O_NONBLOCK) < 0))
220 : {
221 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to set socket to non-blocking mode");
222 0 : return false;
223 : }
224 :
225 14 : m_sock = sockFd;
226 :
227 14 : return true;
228 : }
229 :
230 23 : bool ChannelImpl::initChannel()
231 : {
232 : // create epoll so can listen for timeouts as well as socket messages
233 23 : m_epollFd = epoll_create1(EPOLL_CLOEXEC);
234 23 : if (m_epollFd < 0)
235 : {
236 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll failed");
237 0 : return false;
238 : }
239 :
240 : // add the socket to epoll
241 23 : epoll_event sockEvent = {.events = EPOLLIN, .data = {.fd = m_sock}};
242 23 : if (epoll_ctl(m_epollFd, EPOLL_CTL_ADD, m_sock, &sockEvent) != 0)
243 : {
244 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
245 0 : return false;
246 : }
247 :
248 23 : m_timerFd = timerfd_create(CLOCK_MONOTONIC, TFD_CLOEXEC | TFD_NONBLOCK);
249 23 : if (m_timerFd < 0)
250 : {
251 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "timerfd_create failed");
252 0 : return false;
253 : }
254 :
255 : // add the timer event to epoll
256 23 : epoll_event timerEvent = {.events = EPOLLIN, .data = {.fd = m_timerFd}};
257 23 : if (epoll_ctl(m_epollFd, EPOLL_CTL_ADD, m_timerFd, &timerEvent) != 0)
258 : {
259 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
260 0 : return false;
261 : }
262 :
263 : // and lastly the eventfd to wake the poll loop
264 23 : m_eventFd = eventfd(0, EFD_CLOEXEC);
265 23 : if (m_eventFd < 0)
266 : {
267 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "eventfd create failed");
268 0 : return false;
269 : }
270 :
271 : // add the timer event to epoll
272 23 : epoll_event wakeEvent = {.events = EPOLLIN, .data = {.fd = m_eventFd}};
273 23 : if (epoll_ctl(m_epollFd, EPOLL_CTL_ADD, m_eventFd, &wakeEvent) != 0)
274 : {
275 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
276 0 : return false;
277 : }
278 :
279 23 : return true;
280 : }
281 :
282 23 : void ChannelImpl::termChannel()
283 : {
284 : // close the socket and the epoll and timer fds
285 23 : if (m_sock >= 0)
286 0 : disconnectNoLock();
287 23 : if ((m_epollFd >= 0) && (close(m_epollFd) != 0))
288 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "closing epoll fd failed");
289 23 : if ((m_timerFd >= 0) && (close(m_timerFd) != 0))
290 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "closing timer fd failed");
291 23 : if ((m_eventFd >= 0) && (close(m_eventFd) != 0))
292 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "closing event fd failed");
293 :
294 : // if any method calls are still outstanding then complete them with errors now
295 23 : for (auto &entry : m_methodCalls)
296 : {
297 0 : completeWithError(&entry.second, "Channel destructed");
298 : }
299 :
300 23 : m_methodCalls.clear();
301 : }
302 :
303 23 : void ChannelImpl::disconnect()
304 : {
305 : // disconnect from the socket
306 : {
307 23 : std::lock_guard<std::mutex> locker(m_lock);
308 :
309 23 : if (m_sock < 0)
310 1 : return;
311 :
312 22 : disconnectNoLock();
313 14 : }
314 :
315 : // wake the wait(...) call so client code is blocked there it is woken
316 22 : if (m_eventFd >= 0)
317 : {
318 22 : uint64_t wakeup = 1;
319 22 : if (TEMP_FAILURE_RETRY(write(m_eventFd, &wakeup, sizeof(wakeup))) != sizeof(wakeup))
320 : {
321 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to write to wake event fd");
322 : }
323 : }
324 : }
325 :
326 : // -----------------------------------------------------------------------------
327 : /*!
328 : \internal
329 :
330 :
331 : */
332 23 : void ChannelImpl::disconnectNoLock()
333 : {
334 23 : if (m_sock < 0)
335 : {
336 0 : RIALTO_IPC_LOG_WARN("not connected\n");
337 0 : return;
338 : }
339 :
340 : // remove the socket from epoll
341 23 : if (epoll_ctl(m_epollFd, EPOLL_CTL_DEL, m_sock, nullptr) != 0)
342 : {
343 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to remove socket");
344 : }
345 :
346 : // shutdown and close the socket
347 23 : if (shutdown(m_sock, SHUT_RDWR) != 0)
348 : {
349 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "shutdown error");
350 : }
351 23 : if (close(m_sock) != 0)
352 : {
353 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "close error");
354 : }
355 :
356 23 : m_sock = -1;
357 : }
358 :
359 12686 : bool ChannelImpl::isConnected() const
360 : {
361 12686 : return isConnectedInternal();
362 : }
363 :
364 12709 : bool ChannelImpl::isConnectedInternal() const
365 : {
366 12709 : std::lock_guard<std::mutex> locker(m_lock);
367 12709 : return (m_sock >= 0);
368 : }
369 :
370 0 : int ChannelImpl::fd() const
371 : {
372 0 : return m_epollFd;
373 : }
374 :
375 3159 : bool ChannelImpl::wait(int timeoutMSecs)
376 : {
377 3159 : if ((m_epollFd < 0) || !isConnected())
378 : {
379 0 : return false;
380 : }
381 :
382 : // wait for any event (with timeout)
383 : struct pollfd fds[2];
384 3159 : fds[0].fd = m_epollFd;
385 3159 : fds[0].events = POLLIN;
386 :
387 3159 : int rc = TEMP_FAILURE_RETRY(poll(fds, 1, timeoutMSecs));
388 3159 : if (rc < 0)
389 : {
390 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "poll failed?");
391 0 : return false;
392 : }
393 :
394 3159 : return isConnected();
395 : }
396 :
397 3183 : bool ChannelImpl::process()
398 : {
399 3183 : if (!isConnected())
400 14 : return false;
401 :
402 : struct epoll_event events[3];
403 3169 : int rc = TEMP_FAILURE_RETRY(epoll_wait(m_epollFd, events, 3, 0));
404 3169 : if (rc < 0)
405 : {
406 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_wait failed");
407 0 : return false;
408 : }
409 :
410 : enum
411 : {
412 : HaveSocketEvent = 0x1,
413 : HaveTimeoutEvent = 0x2,
414 : HaveWakeEvent = 0x4
415 : };
416 3169 : unsigned eventsMask = 0;
417 3189 : for (int i = 0; i < rc; i++)
418 : {
419 20 : if (events[i].data.fd == m_sock)
420 18 : eventsMask |= HaveSocketEvent;
421 2 : else if (events[i].data.fd == m_timerFd)
422 2 : eventsMask |= HaveTimeoutEvent;
423 0 : else if (events[i].data.fd == m_eventFd)
424 0 : eventsMask |= HaveWakeEvent;
425 : }
426 :
427 3169 : if ((eventsMask & HaveSocketEvent) && !processSocketEvent())
428 : {
429 1 : std::map<uint64_t, MethodCall> callsToDelete;
430 : {
431 1 : std::lock_guard<std::mutex> locker(m_lock);
432 1 : callsToDelete = m_methodCalls;
433 1 : m_methodCalls.clear();
434 : }
435 :
436 1 : for (auto &entry : callsToDelete)
437 : {
438 0 : completeWithError(&entry.second, "Socket down");
439 : }
440 1 : return false;
441 : }
442 :
443 3168 : if (eventsMask & HaveTimeoutEvent)
444 2 : processTimeoutEvent();
445 :
446 3168 : if (eventsMask & HaveWakeEvent)
447 0 : processWakeEvent();
448 :
449 3168 : return isConnected();
450 : }
451 :
452 16 : bool ChannelImpl::unsubscribe(int eventTag)
453 : {
454 16 : std::lock_guard<std::mutex> locker(m_eventsLock);
455 16 : bool success = false;
456 :
457 24 : for (auto it = m_eventHandlers.begin(); it != m_eventHandlers.end(); it++)
458 : {
459 24 : if (it->second.id == eventTag)
460 : {
461 16 : m_eventHandlers.erase(it);
462 16 : success = true;
463 16 : break;
464 : }
465 : }
466 :
467 16 : return success;
468 : }
469 :
470 : // -----------------------------------------------------------------------------
471 : /*!
472 : \internal
473 :
474 :
475 : */
476 18 : bool ChannelImpl::processSocketEvent()
477 : {
478 : static std::mutex bufLock;
479 18 : std::lock_guard<std::mutex> bufLocker(bufLock);
480 :
481 18 : static std::vector<uint8_t> dataBuf(kMaxMessageSize);
482 18 : static std::vector<uint8_t> ctrlBuf(CMSG_SPACE(SCM_MAX_FD * sizeof(int)));
483 :
484 : // read all messages from the client socket, we break out if the socket is closed
485 : // or EWOULDBLOCK is returned on a read (ie. no more messages to read)
486 : while (true)
487 : {
488 35 : struct msghdr msg = {nullptr};
489 35 : struct iovec io = {.iov_base = dataBuf.data(), .iov_len = dataBuf.size()};
490 :
491 35 : bzero(&msg, sizeof(msg));
492 35 : msg.msg_iov = &io;
493 35 : msg.msg_iovlen = 1;
494 35 : msg.msg_control = ctrlBuf.data();
495 35 : msg.msg_controllen = ctrlBuf.size();
496 :
497 : // read one message
498 35 : ssize_t rd = TEMP_FAILURE_RETRY(recvmsg(m_sock, &msg, MSG_CMSG_CLOEXEC));
499 35 : if (rd < 0)
500 : {
501 17 : if (errno != EWOULDBLOCK)
502 : {
503 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "error reading client socket");
504 :
505 0 : std::lock_guard<std::mutex> locker(m_lock);
506 0 : disconnectNoLock();
507 0 : return false;
508 : }
509 :
510 17 : break;
511 : }
512 18 : else if (rd == 0)
513 : {
514 : // server closed connection, and we've read all data
515 1 : RIALTO_IPC_LOG_INFO("socket remote end closed, disconnecting channel");
516 :
517 1 : std::lock_guard<std::mutex> locker(m_lock);
518 1 : disconnectNoLock();
519 1 : return false;
520 : }
521 17 : else if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))
522 : {
523 0 : RIALTO_IPC_LOG_WARN("received truncated message from server, discarding");
524 :
525 : // make sure to close all the fds, otherwise we'll leak them, this
526 : // will read the fds and return in a vector, which will then be
527 : // destroyed, closing all the fds
528 0 : readMessageFds(&msg, 16);
529 : }
530 : else
531 : {
532 : // if there is control data then assume fd(s) have been passed
533 17 : std::vector<FileDescriptor> fds;
534 17 : if (msg.msg_controllen > 0)
535 : {
536 0 : fds = readMessageFds(&msg, 32);
537 : }
538 :
539 : // process the message from the server
540 17 : processServerMessage(dataBuf.data(), rd, &fds);
541 : }
542 : }
543 :
544 17 : return true;
545 11 : }
546 :
547 : // -----------------------------------------------------------------------------
548 : /*!
549 : \internal
550 :
551 : Called from process() to check if the timerfd has expired and if so cancel
552 : the any outstanding method calls that have now timed-out.
553 :
554 : */
555 2 : void ChannelImpl::processTimeoutEvent()
556 : {
557 : // read the timerfd to clear any expirations
558 : uint64_t expirations;
559 2 : ssize_t rd = TEMP_FAILURE_RETRY(read(m_timerFd, &expirations, sizeof(expirations)));
560 2 : if (rd < 0)
561 : {
562 1 : if (errno != EWOULDBLOCK)
563 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "error reading timerfd");
564 1 : return;
565 : }
566 :
567 : // check if any method call has no expired
568 1 : std::unique_lock<std::mutex> locker(m_lock);
569 :
570 : // stores the timed-out method calls
571 1 : std::vector<MethodCall> timedOuts;
572 :
573 : // remove the method calls that have expired
574 1 : const auto kNow = std::chrono::steady_clock::now();
575 1 : auto it = m_methodCalls.begin();
576 2 : while (it != m_methodCalls.end())
577 : {
578 1 : if (kNow >= it->second.timeoutDeadline)
579 : {
580 1 : timedOuts.emplace_back(it->second);
581 1 : it = m_methodCalls.erase(it);
582 : }
583 : else
584 : {
585 0 : ++it;
586 : }
587 : }
588 :
589 : // if we still have method calls available, then re-calculate the timer for the next timeout
590 1 : if (!m_methodCalls.empty())
591 : {
592 0 : updateTimeoutTimer();
593 : }
594 :
595 : // drop the lock and now terminate the timed out method calls
596 1 : locker.unlock();
597 :
598 2 : for (auto &call : timedOuts)
599 : {
600 1 : completeWithError(&call, "Timed out");
601 : }
602 : }
603 :
604 : // -----------------------------------------------------------------------------
605 : /*!
606 : \internal
607 :
608 : Called from process() when the eventfd was used to wake the event loop. All
609 : the function does is read the eventfd to clear it's value.
610 :
611 : */
612 0 : void ChannelImpl::processWakeEvent()
613 : {
614 : uint64_t ignore;
615 0 : if (TEMP_FAILURE_RETRY(read(m_eventFd, &ignore, sizeof(ignore))) != sizeof(ignore))
616 : {
617 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "Failed to read wake eventfd to clear it");
618 : }
619 : }
620 :
621 : // -----------------------------------------------------------------------------
622 : /*!
623 : \internal
624 :
625 : Updates the timerfd to the time of the next method call timeout. If method
626 : calls are pending then the timer is disabled.
627 :
628 : This should be called whenever a new method is called or a method has
629 : completed.
630 :
631 : \note Must be called while holding the m_lock mutex.
632 :
633 : */
634 27 : void ChannelImpl::updateTimeoutTimer()
635 : {
636 27 : struct itimerspec ts = {{0}};
637 :
638 : // if no method calls then just disarm the timer
639 27 : if (!m_methodCalls.empty())
640 : {
641 : // otherwise, find the next soonest timeout
642 14 : std::chrono::steady_clock::time_point nextTimeout = std::chrono::steady_clock::time_point::max();
643 : auto nextTimeoutCall =
644 14 : std::min_element(m_methodCalls.begin(), m_methodCalls.end(),
645 0 : [](const auto &elem, const auto ¤tMin)
646 0 : { return elem.second.timeoutDeadline < currentMin.second.timeoutDeadline; });
647 14 : if (nextTimeoutCall != m_methodCalls.end())
648 : {
649 14 : nextTimeout = nextTimeoutCall->second.timeoutDeadline;
650 : }
651 :
652 : // set the timerfd to the next duration
653 : const std::chrono::microseconds kDuration =
654 14 : std::chrono::duration_cast<std::chrono::microseconds>(nextTimeout - std::chrono::steady_clock::now());
655 14 : if (kDuration <= std::chrono::microseconds::zero())
656 : {
657 0 : ts.it_value.tv_nsec = 1000;
658 : }
659 : else
660 : {
661 14 : ts.it_value.tv_sec = static_cast<time_t>(std::chrono::duration_cast<std::chrono::seconds>(kDuration).count());
662 14 : ts.it_value.tv_nsec = static_cast<int32_t>((kDuration.count() % 1000000) * 1000);
663 : }
664 :
665 14 : RIALTO_IPC_LOG_DEBUG("next timeout in %" PRId64 "us - %ld.%09lds", kDuration.count(), ts.it_value.tv_sec,
666 : ts.it_value.tv_nsec);
667 : }
668 :
669 : // write the timeout value
670 27 : if (timerfd_settime(m_timerFd, 0, &ts, nullptr) != 0)
671 : {
672 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to write to timerfd");
673 : }
674 27 : }
675 :
676 : // -----------------------------------------------------------------------------
677 : /*!
678 : \internal
679 :
680 : Processes a single message from the server, it may be a method call response
681 : or an event.
682 :
683 :
684 : */
685 17 : void ChannelImpl::processServerMessage(const uint8_t *data, size_t dataLen, std::vector<FileDescriptor> *fds)
686 : {
687 : // parse the message
688 17 : transport::MessageFromServer message;
689 17 : if (!message.ParseFromArray(data, static_cast<int>(dataLen)))
690 : {
691 0 : RIALTO_IPC_LOG_ERROR("invalid message from server");
692 0 : return;
693 : }
694 :
695 : // check if an event or a reply to a request
696 17 : if (message.has_reply())
697 : {
698 8 : processReplyFromServer(message.reply(), fds);
699 : }
700 9 : else if (message.has_error())
701 : {
702 5 : processErrorFromServer(message.error());
703 : }
704 4 : else if (message.has_event())
705 : {
706 4 : processEventFromServer(message.event(), fds);
707 : }
708 : else
709 : {
710 0 : RIALTO_IPC_LOG_ERROR("message from server is missing reply or event type");
711 : }
712 17 : }
713 :
714 : // -----------------------------------------------------------------------------
715 : /*!
716 : \internal
717 :
718 :
719 : */
720 8 : void ChannelImpl::processReplyFromServer(const transport::MethodCallReply &reply, std::vector<FileDescriptor> *fds)
721 : {
722 8 : RIALTO_IPC_LOG_DEBUG("processing reply from server");
723 :
724 8 : std::unique_lock<std::mutex> locker(m_lock);
725 :
726 : // find the original request
727 8 : const uint64_t kSerialId = reply.reply_id();
728 8 : auto it = m_methodCalls.find(kSerialId);
729 8 : if (it == m_methodCalls.end())
730 : {
731 0 : RIALTO_IPC_LOG_ERROR("failed to find request for received reply with id %" PRIu64 "", reply.reply_id());
732 0 : return;
733 : }
734 :
735 : // take the method call and remove from the map of outstanding calls
736 8 : MethodCall methodCall = it->second;
737 8 : m_methodCalls.erase(it);
738 :
739 : // update the timeout timer now a method call has been processed
740 8 : updateTimeoutTimer();
741 :
742 : // can now drop the lock
743 8 : locker.unlock();
744 :
745 : // this is an actual reply so try and read it
746 8 : if (!methodCall.response->ParseFromString(reply.reply_message()))
747 : {
748 0 : RIALTO_IPC_LOG_ERROR("failed to parse method reply from server");
749 0 : completeWithError(&methodCall, "Failed to parse reply message");
750 : }
751 8 : else if (!addReplyFileDescriptors(methodCall.response, fds))
752 : {
753 0 : RIALTO_IPC_LOG_ERROR("mismatch of file descriptors to the reply");
754 0 : completeWithError(&methodCall, "Mismatched file descriptors in message");
755 : }
756 8 : else if (methodCall.closure)
757 : {
758 8 : RIALTO_IPC_LOG_DEBUG("reply{ serial %" PRIu64 " } - %s { %s }", kSerialId,
759 : methodCall.response->GetTypeName().c_str(), methodCall.response->ShortDebugString().c_str());
760 :
761 8 : complete(&methodCall);
762 : }
763 : }
764 :
765 : // -----------------------------------------------------------------------------
766 : /*!
767 : \internal
768 :
769 :
770 : */
771 5 : void ChannelImpl::processErrorFromServer(const transport::MethodCallError &error)
772 : {
773 5 : RIALTO_IPC_LOG_DEBUG("processing error from server");
774 :
775 5 : std::unique_lock<std::mutex> locker(m_lock);
776 :
777 : // find the original request
778 5 : const uint64_t kSerialId = error.reply_id();
779 5 : auto it = m_methodCalls.find(kSerialId);
780 5 : if (it == m_methodCalls.end())
781 : {
782 0 : RIALTO_IPC_LOG_ERROR("failed to find request for received reply with id %" PRIu64 "", error.reply_id());
783 0 : return;
784 : }
785 :
786 : // take the method call and remove from the map of outstanding calls
787 5 : MethodCall methodCall = it->second;
788 5 : m_methodCalls.erase(it);
789 :
790 : // update the timeout timer now a method call has been processed
791 5 : updateTimeoutTimer();
792 :
793 : // can now drop the lock
794 5 : locker.unlock();
795 :
796 5 : RIALTO_IPC_LOG_DEBUG("error{ serial %" PRIu64 " } - %s", kSerialId, error.error_reason().c_str());
797 :
798 : // complete the call with an error
799 5 : completeWithError(&methodCall, error.error_reason());
800 : }
801 :
802 : // -----------------------------------------------------------------------------
803 : /*!
804 : \internal
805 :
806 :
807 : */
808 4 : void ChannelImpl::processEventFromServer(const transport::EventFromServer &event, std::vector<FileDescriptor> *fds)
809 : {
810 4 : RIALTO_IPC_LOG_DEBUG("processing event from server");
811 :
812 4 : const std::string &kEventName = event.event_name();
813 :
814 4 : std::lock_guard<std::mutex> locker(m_eventsLock);
815 :
816 4 : auto range = m_eventHandlers.equal_range(kEventName);
817 4 : if (range.first == range.second)
818 : {
819 0 : RIALTO_IPC_LOG_WARN("no handler for event %s", kEventName.c_str());
820 0 : return;
821 : }
822 :
823 4 : const google::protobuf::Descriptor *kDescriptor = range.first->second.descriptor;
824 :
825 : const google::protobuf::Message *kPrototype =
826 4 : google::protobuf::MessageFactory::generated_factory()->GetPrototype(kDescriptor);
827 4 : if (!kPrototype)
828 : {
829 0 : RIALTO_IPC_LOG_ERROR("failed to create prototype for event %s", kEventName.c_str());
830 0 : return;
831 : }
832 :
833 4 : std::shared_ptr<google::protobuf::Message> message(kPrototype->New());
834 4 : if (!message)
835 : {
836 0 : RIALTO_IPC_LOG_ERROR("failed to create mutable message from prototype");
837 0 : return;
838 : }
839 :
840 4 : if (!message->ParseFromString(event.message()))
841 : {
842 0 : RIALTO_IPC_LOG_ERROR("failed to parse message for event %s", kEventName.c_str());
843 : }
844 4 : else if (!addReplyFileDescriptors(message.get(), fds))
845 : {
846 0 : RIALTO_IPC_LOG_ERROR("mismatch of file descriptors to the reply");
847 : }
848 : else
849 : {
850 4 : RIALTO_IPC_LOG_DEBUG("event{ %s } - %s { %s }", kEventName.c_str(), message->GetTypeName().c_str(),
851 : message->ShortDebugString().c_str());
852 :
853 8 : for (auto it = range.first; it != range.second; ++it)
854 : {
855 4 : it->second.handler(message);
856 : }
857 : }
858 : }
859 :
860 : // -----------------------------------------------------------------------------
861 : /*!
862 : \internal
863 : \static
864 :
865 : Reads all the file descriptors from a unix domain socket received \a msg.
866 : It returns the file descriptors as a vector of FileDescriptor objects,
867 : these objects safely store the fd and close them when they're destructed.
868 :
869 : The \a limit specifies the maximum number of fds to store, if more were
870 : sent then they are automatically closed and not returned in the vector.
871 :
872 : */
873 0 : std::vector<FileDescriptor> ChannelImpl::readMessageFds(const struct msghdr *msg, size_t limit)
874 : {
875 0 : std::vector<FileDescriptor> fds;
876 :
877 0 : for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg); cmsg != nullptr; cmsg = CMSG_NXTHDR((struct msghdr *)msg, cmsg))
878 : {
879 0 : if ((cmsg->cmsg_level == SOL_SOCKET) && (cmsg->cmsg_type == SCM_RIGHTS))
880 : {
881 0 : const unsigned kFdsLength = cmsg->cmsg_len - CMSG_LEN(0);
882 0 : if ((kFdsLength < sizeof(int)) || ((kFdsLength % sizeof(int)) != 0))
883 : {
884 0 : RIALTO_IPC_LOG_ERROR("invalid fd array size");
885 : }
886 : else
887 : {
888 0 : const size_t n = kFdsLength / sizeof(int);
889 0 : RIALTO_IPC_LOG_DEBUG("received %zu fds", n);
890 :
891 0 : fds.reserve(std::min(limit, n));
892 :
893 0 : const int *kFds = reinterpret_cast<int *>(CMSG_DATA(cmsg));
894 0 : for (size_t i = 0; i < n; i++)
895 : {
896 0 : RIALTO_IPC_LOG_DEBUG("received fd %d", kFds[i]);
897 :
898 0 : if (fds.size() >= limit)
899 : {
900 0 : RIALTO_IPC_LOG_ERROR(
901 : "received to many file descriptors, exceeding max per message, closing left overs");
902 : }
903 : else
904 : {
905 0 : firebolt::rialto::ipc::FileDescriptor fd(kFds[i]);
906 0 : if (!fd.isValid())
907 : {
908 0 : RIALTO_IPC_LOG_ERROR("received invalid fd (couldn't dup)");
909 : }
910 : else
911 : {
912 0 : fds.emplace_back(std::move(fd));
913 : }
914 : }
915 :
916 0 : if (close(kFds[i]) != 0)
917 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close received fd");
918 : }
919 : }
920 : }
921 : }
922 :
923 0 : return fds;
924 : }
925 :
926 : // -----------------------------------------------------------------------------
927 : /*!
928 : \static
929 : \internal
930 :
931 : Places the received file descriptors into the protobuf message.
932 :
933 : It works by iterating over the fields in the message, finding ones that are
934 : marked as 'field_is_fd' and then replacing the received integer value with
935 : an actual file descriptor.
936 :
937 : */
938 12 : bool ChannelImpl::addReplyFileDescriptors(google::protobuf::Message *reply,
939 : std::vector<firebolt::rialto::ipc::FileDescriptor> *fds)
940 : {
941 12 : auto fdIterator = fds->begin();
942 :
943 12 : const google::protobuf::Descriptor *kDescriptor = reply->GetDescriptor();
944 12 : const google::protobuf::Reflection *kReflection = nullptr;
945 :
946 12 : const int n = kDescriptor->field_count();
947 25 : for (int i = 0; i < n; i++)
948 : {
949 13 : auto fieldDescriptor = kDescriptor->field(i);
950 13 : if (fieldDescriptor->options().HasExtension(::firebolt::rialto::ipc::field_is_fd) &&
951 0 : fieldDescriptor->options().GetExtension(::firebolt::rialto::ipc::field_is_fd))
952 : {
953 0 : if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
954 : {
955 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
956 0 : return false;
957 : }
958 :
959 0 : if (!kReflection)
960 : {
961 0 : kReflection = reply->GetReflection();
962 : }
963 :
964 0 : if (kReflection->HasField(*reply, fieldDescriptor))
965 : {
966 0 : if (fdIterator == fds->end())
967 : {
968 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but none or too few were supplied");
969 0 : return false;
970 : }
971 :
972 0 : kReflection->SetInt32(reply, fieldDescriptor, fdIterator->fd());
973 0 : ++fdIterator;
974 : }
975 : }
976 : }
977 :
978 12 : if (fdIterator != fds->end())
979 : {
980 0 : RIALTO_IPC_LOG_ERROR("received too many file descriptors in the message");
981 0 : return false;
982 : }
983 :
984 : // we now need to release all the fds stored in the vector, otherwise they
985 : // will be closed when the vector is destroyed. From now onwards it is the
986 : // caller's responsibility to close the fds in the returned protobuf message
987 : // object
988 12 : for (firebolt::rialto::ipc::FileDescriptor &fd : *fds)
989 : {
990 0 : fd.release();
991 : }
992 :
993 12 : return true;
994 : }
995 :
996 : // -----------------------------------------------------------------------------
997 : /*!
998 : \internal
999 : \static
1000 :
1001 :
1002 : */
1003 8 : void ChannelImpl::complete(MethodCall *call)
1004 : {
1005 8 : if (call->closure)
1006 : {
1007 8 : call->closure->Run();
1008 8 : call->closure = nullptr;
1009 : }
1010 : }
1011 :
1012 : // -----------------------------------------------------------------------------
1013 : /*!
1014 : \internal
1015 : \static
1016 :
1017 :
1018 : */
1019 6 : void ChannelImpl::completeWithError(MethodCall *call, std::string reason)
1020 : {
1021 6 : RIALTO_IPC_LOG_DEBUG("completing method call with error '%s'", reason.c_str());
1022 :
1023 6 : if (call->controller)
1024 : {
1025 6 : call->controller->setMethodCallFailed(std::move(reason));
1026 6 : call->controller = nullptr;
1027 : }
1028 :
1029 6 : if (call->closure)
1030 : {
1031 6 : call->closure->Run();
1032 6 : call->closure = nullptr;
1033 : }
1034 : }
1035 :
1036 : // -----------------------------------------------------------------------------
1037 : /*!
1038 : \static
1039 :
1040 : Iterates through the message and finds any file descriptor type fields, if
1041 : found it adds the fds to the returned vector.
1042 :
1043 : */
1044 15 : std::vector<int> ChannelImpl::getMessageFds(const google::protobuf::Message &message)
1045 : {
1046 15 : std::vector<int> fds;
1047 :
1048 15 : auto descriptor = message.GetDescriptor();
1049 15 : const int n = descriptor->field_count();
1050 47 : for (int i = 0; i < n; i++)
1051 : {
1052 32 : auto fieldDescriptor = descriptor->field(i);
1053 32 : if (fieldDescriptor->options().HasExtension(::firebolt::rialto::ipc::field_is_fd) &&
1054 0 : fieldDescriptor->options().GetExtension(::firebolt::rialto::ipc::field_is_fd))
1055 : {
1056 0 : if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
1057 : {
1058 0 : RIALTO_IPC_LOG_ERROR("field '%s' is marked as containing an fd but not an int32 type",
1059 : fieldDescriptor->full_name().c_str());
1060 : }
1061 : else
1062 : {
1063 0 : auto reflection = message.GetReflection();
1064 0 : int fd = reflection->GetInt32(message, fieldDescriptor);
1065 0 : fds.emplace_back(fd);
1066 : }
1067 : }
1068 : }
1069 :
1070 15 : return fds;
1071 : }
1072 :
1073 : // -----------------------------------------------------------------------------
1074 : /*!
1075 : \overload
1076 :
1077 : \quote
1078 : Call the given method of the remote service. The signature of this
1079 : procedure looks the same as Service::CallMethod(), but the requirements
1080 : are less strict in one important way: the request and response objects
1081 : need not be of any specific class as long as their descriptors are
1082 : method->input_type() and method->output_type().
1083 :
1084 : */
1085 15 : void ChannelImpl::CallMethod(const google::protobuf::MethodDescriptor *method, // NOLINT(build/function_format)
1086 : google::protobuf::RpcController *controller, const google::protobuf::Message *request,
1087 : google::protobuf::Message *response, google::protobuf::Closure *done)
1088 : {
1089 15 : MethodCall methodCall{std::chrono::steady_clock::now() + m_timeout,
1090 15 : dynamic_cast<ClientControllerImpl *>(controller), response, done};
1091 :
1092 : //
1093 15 : const uint64_t kSerialId = m_serialCounter++;
1094 :
1095 : // create the transport request
1096 15 : transport::MessageToServer message;
1097 15 : transport::MethodCall *call = message.mutable_call();
1098 15 : call->set_serial_id(kSerialId);
1099 15 : call->set_service_name(method->service()->full_name());
1100 15 : call->set_method_name(method->name());
1101 :
1102 : // copy in the actual message data
1103 15 : std::string reqString = request->SerializeAsString();
1104 15 : call->set_request_message(std::move(reqString));
1105 :
1106 15 : const size_t kRequiredDataLen = message.ByteSizeLong();
1107 15 : if (kRequiredDataLen > kMaxMessageSize)
1108 : {
1109 0 : RIALTO_IPC_LOG_ERROR("method call to big to send (%zu, max %zu", kRequiredDataLen, kMaxMessageSize);
1110 0 : completeWithError(&methodCall, "Method call to big");
1111 0 : return;
1112 : }
1113 :
1114 : // extract the fds from the message
1115 15 : const std::vector<int> kFds = getMessageFds(*request);
1116 15 : const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
1117 :
1118 : // build the socket message to send
1119 : auto msgBuf =
1120 15 : m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + kRequiredDataLen);
1121 :
1122 15 : auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
1123 15 : bzero(header, sizeof(msghdr));
1124 :
1125 15 : auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
1126 15 : header->msg_control = ctrl;
1127 15 : header->msg_controllen = kRequiredCtrlLen;
1128 :
1129 15 : auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
1130 15 : header->msg_iov = iov;
1131 15 : header->msg_iovlen = 1;
1132 :
1133 15 : auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
1134 15 : iov->iov_base = data;
1135 15 : iov->iov_len = kRequiredDataLen;
1136 :
1137 : // copy in the data
1138 15 : message.SerializeWithCachedSizesToArray(data);
1139 :
1140 : // next check if the request is sending any fd's
1141 15 : if (!kFds.empty())
1142 : {
1143 0 : struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
1144 0 : if (!cmsg)
1145 : {
1146 0 : RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
1147 0 : completeWithError(&methodCall, "Internal error");
1148 0 : return;
1149 : }
1150 :
1151 0 : cmsg->cmsg_level = SOL_SOCKET;
1152 0 : cmsg->cmsg_type = SCM_RIGHTS;
1153 0 : cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
1154 0 : memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
1155 0 : header->msg_controllen = cmsg->cmsg_len;
1156 : }
1157 :
1158 : // check if the method is expecting a reply
1159 16 : const bool kNoReplyExpected = method->options().HasExtension(::firebolt::rialto::ipc::no_reply) &&
1160 1 : method->options().GetExtension(::firebolt::rialto::ipc::no_reply);
1161 :
1162 : // finally, send the message
1163 15 : std::unique_lock<std::mutex> locker(m_lock);
1164 :
1165 15 : if (m_sock < 0)
1166 : {
1167 0 : locker.unlock();
1168 0 : completeWithError(&methodCall, "Not connected");
1169 : }
1170 15 : else if (sendmsg(m_sock, header, MSG_NOSIGNAL) != static_cast<ssize_t>(kRequiredDataLen))
1171 : {
1172 0 : locker.unlock();
1173 0 : completeWithError(&methodCall, "Failed to send message");
1174 : }
1175 : else
1176 : {
1177 15 : RIALTO_IPC_LOG_DEBUG("call{ serial %" PRIu64 " } - %s.%s { %s }", kSerialId, call->service_name().c_str(),
1178 : call->method_name().c_str(), request->ShortDebugString().c_str());
1179 :
1180 15 : if (kNoReplyExpected)
1181 : {
1182 : // no reply from server is expected, however if the caller supplied
1183 : // a closure (it shouldn't) we should still call it now to indicate
1184 : // the method call has been made
1185 1 : if (done)
1186 1 : done->Run();
1187 : }
1188 : else
1189 : {
1190 : // add the message to the queue so we pick-up the reply
1191 14 : m_methodCalls.emplace(kSerialId, methodCall);
1192 :
1193 : // update the single timeout timer
1194 14 : updateTimeoutTimer();
1195 : }
1196 : }
1197 7 : }
1198 :
1199 46 : int ChannelImpl::subscribeImpl(const std::string &kEventName, const google::protobuf::Descriptor *descriptor,
1200 : EventHandler &&handler)
1201 : {
1202 46 : std::lock_guard<std::mutex> locker(m_eventsLock);
1203 :
1204 46 : const int kTag = m_eventTagCounter++;
1205 46 : m_eventHandlers.emplace(kEventName, Event{kTag, descriptor, std::move(handler)});
1206 :
1207 46 : return kTag;
1208 : }
1209 : }; // namespace firebolt::rialto::ipc
|