Line data Source code
1 : /*
2 : * If not stated otherwise in this file or this component's LICENSE file the
3 : * following copyright and licenses apply:
4 : *
5 : * Copyright 2022 Sky UK
6 : *
7 : * Licensed under the Apache License, Version 2.0 (the "License");
8 : * you may not use this file except in compliance with the License.
9 : * You may obtain a copy of the License at
10 : *
11 : * http://www.apache.org/licenses/LICENSE-2.0
12 : *
13 : * Unless required by applicable law or agreed to in writing, software
14 : * distributed under the License is distributed on an "AS IS" BASIS,
15 : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 : * See the License for the specific language governing permissions and
17 : * limitations under the License.
18 : */
19 :
20 : #include <algorithm>
21 : #include <cinttypes>
22 : #include <cstdarg>
23 : #include <memory>
24 : #include <utility>
25 :
26 : #include <fcntl.h>
27 : #include <poll.h>
28 : #include <sys/epoll.h>
29 : #include <sys/eventfd.h>
30 : #include <sys/socket.h>
31 : #include <sys/timerfd.h>
32 : #include <sys/un.h>
33 : #include <unistd.h>
34 :
35 : #include "IpcChannelImpl.h"
36 : #include "IpcLogging.h"
37 : #include "rialtoipc.pb.h"
38 :
39 : #if !defined(SCM_MAX_FD)
40 : #define SCM_MAX_FD 255
41 : #endif
42 :
43 : namespace
44 : {
45 : constexpr size_t kMaxMessageSize{128 * 1024};
46 : const std::chrono::milliseconds kDefaultIpcTimeout{3000};
47 :
48 33 : std::chrono::milliseconds getIpcTimeout()
49 : {
50 33 : const char *kCustomTimeout = getenv("RIALTO_CLIENT_IPC_TIMEOUT");
51 33 : std::chrono::milliseconds timeout{kDefaultIpcTimeout};
52 33 : if (kCustomTimeout)
53 : {
54 : try
55 : {
56 0 : timeout = std::chrono::milliseconds{std::stoull(kCustomTimeout)};
57 0 : RIALTO_IPC_LOG_INFO("Using custom Ipc timeout: %sms", kCustomTimeout);
58 : }
59 0 : catch (const std::exception &e)
60 : {
61 0 : RIALTO_IPC_LOG_ERROR("Custom Ipc timeout invalid, ignoring: %s", kCustomTimeout);
62 : }
63 : }
64 33 : return timeout;
65 : }
66 : } // namespace
67 :
68 : namespace firebolt::rialto::ipc
69 : {
70 :
71 33 : std::shared_ptr<IChannelFactory> IChannelFactory::createFactory()
72 : {
73 33 : std::shared_ptr<IChannelFactory> factory;
74 : try
75 : {
76 33 : factory = std::make_shared<ChannelFactory>();
77 : }
78 0 : catch (const std::exception &e)
79 : {
80 0 : RIALTO_IPC_LOG_ERROR("Failed to create the ipc channel factory, reason: %s", e.what());
81 : }
82 :
83 33 : return factory;
84 : }
85 :
86 17 : std::shared_ptr<IChannel> ChannelFactory::createChannel(int sockFd)
87 : {
88 17 : std::shared_ptr<IChannel> channel;
89 : try
90 : {
91 17 : channel = std::make_shared<ChannelImpl>(sockFd);
92 : }
93 1 : catch (const std::exception &e)
94 : {
95 1 : RIALTO_IPC_LOG_ERROR("Failed to create the ipc channel with socketFd %d, reason: %s", sockFd, e.what());
96 : }
97 :
98 17 : return channel;
99 : }
100 :
101 16 : std::shared_ptr<IChannel> ChannelFactory::createChannel(const std::string &socketPath)
102 : {
103 16 : std::shared_ptr<IChannel> channel;
104 : try
105 : {
106 16 : channel = std::make_shared<ChannelImpl>(socketPath);
107 : }
108 3 : catch (const std::exception &e)
109 : {
110 3 : RIALTO_IPC_LOG_ERROR("Failed to create the ipc channel with socketPath %s, reason: %s", socketPath.c_str(),
111 : e.what());
112 : }
113 :
114 16 : return channel;
115 : }
116 :
117 17 : ChannelImpl::ChannelImpl(int sock)
118 17 : : m_sock(-1), m_epollFd(-1), m_timerFd(-1), m_eventFd(-1), m_serialCounter(1), m_timeout(getIpcTimeout()),
119 34 : m_eventTagCounter(1)
120 : {
121 17 : if (!attachSocket(sock))
122 : {
123 1 : throw std::runtime_error("Failed attach the socket");
124 : }
125 16 : if (!initChannel())
126 : {
127 0 : termChannel();
128 0 : throw std::runtime_error("Failed to initalise the channel");
129 : }
130 16 : if (!isConnectedInternal())
131 : {
132 0 : termChannel();
133 0 : throw std::runtime_error("Channel not connected");
134 : }
135 20 : }
136 :
137 16 : ChannelImpl::ChannelImpl(const std::string &socketPath)
138 16 : : m_sock(-1), m_epollFd(-1), m_timerFd(-1), m_eventFd(-1), m_serialCounter(1), m_timeout(getIpcTimeout()),
139 32 : m_eventTagCounter(1)
140 : {
141 16 : if (!createConnectedSocket(socketPath))
142 : {
143 3 : throw std::runtime_error("Failed connect socket");
144 : }
145 13 : if (!initChannel())
146 : {
147 0 : termChannel();
148 0 : throw std::runtime_error("Failed to initalise the channel");
149 : }
150 13 : if (!isConnectedInternal())
151 : {
152 0 : termChannel();
153 0 : throw std::runtime_error("Channel not connected");
154 : }
155 25 : }
156 :
157 29 : ChannelImpl::~ChannelImpl()
158 : {
159 29 : termChannel();
160 : }
161 :
162 16 : bool ChannelImpl::createConnectedSocket(const std::string &socketPath)
163 : {
164 16 : int sock = socket(AF_UNIX, SOCK_SEQPACKET | SOCK_CLOEXEC | SOCK_NONBLOCK, 0);
165 16 : if (sock < 0)
166 : {
167 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to create socket");
168 0 : return false;
169 : }
170 :
171 16 : struct sockaddr_un addr = {0};
172 16 : memset(&addr, 0x00, sizeof(addr));
173 16 : addr.sun_family = AF_UNIX;
174 16 : strncpy(addr.sun_path, socketPath.c_str(), sizeof(addr.sun_path) - 1);
175 :
176 16 : if (::connect(sock, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr)) == -1)
177 : {
178 3 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to connect to %s", socketPath.c_str());
179 3 : close(sock);
180 3 : return false;
181 : }
182 :
183 13 : m_sock = sock;
184 :
185 13 : return true;
186 : }
187 :
188 17 : bool ChannelImpl::attachSocket(int sockFd)
189 : {
190 : // sanity check the supplied socket is of the right type
191 : struct sockaddr addr;
192 17 : socklen_t len = sizeof(addr);
193 17 : if ((getsockname(sockFd, &addr, &len) < 0) || (len < sizeof(sa_family_t)))
194 : {
195 1 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get name of supplied socket");
196 1 : return false;
197 : }
198 16 : if (addr.sa_family != AF_UNIX)
199 : {
200 0 : RIALTO_IPC_LOG_ERROR("supplied client socket is not a unix domain socket");
201 0 : return false;
202 : }
203 :
204 16 : int type = 0;
205 16 : len = sizeof(type);
206 16 : if ((getsockopt(sockFd, SOL_SOCKET, SO_TYPE, &type, &len) < 0) || (len != sizeof(type)))
207 : {
208 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to get type of supplied socket");
209 0 : return false;
210 : }
211 16 : if (type != SOCK_SEQPACKET)
212 : {
213 0 : RIALTO_IPC_LOG_ERROR("supplied client socket is not of type SOCK_SEQPACKET");
214 0 : return false;
215 : }
216 :
217 : // set the O_NONBLOCKING flag on the socket
218 16 : int flags = fcntl(sockFd, F_GETFL);
219 16 : if ((flags < 0) || (fcntl(sockFd, F_SETFL, flags | O_NONBLOCK) < 0))
220 : {
221 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to set socket to non-blocking mode");
222 0 : return false;
223 : }
224 :
225 16 : m_sock = sockFd;
226 :
227 16 : return true;
228 : }
229 :
230 29 : bool ChannelImpl::initChannel()
231 : {
232 : // create epoll so can listen for timeouts as well as socket messages
233 29 : m_epollFd = epoll_create1(EPOLL_CLOEXEC);
234 29 : if (m_epollFd < 0)
235 : {
236 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll failed");
237 0 : return false;
238 : }
239 :
240 : // add the socket to epoll
241 29 : epoll_event sockEvent = {.events = EPOLLIN, .data = {.fd = m_sock}};
242 29 : if (epoll_ctl(m_epollFd, EPOLL_CTL_ADD, m_sock, &sockEvent) != 0)
243 : {
244 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
245 0 : return false;
246 : }
247 :
248 29 : m_timerFd = timerfd_create(CLOCK_MONOTONIC, TFD_CLOEXEC | TFD_NONBLOCK);
249 29 : if (m_timerFd < 0)
250 : {
251 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "timerfd_create failed");
252 0 : return false;
253 : }
254 :
255 : // add the timer event to epoll
256 29 : epoll_event timerEvent = {.events = EPOLLIN, .data = {.fd = m_timerFd}};
257 29 : if (epoll_ctl(m_epollFd, EPOLL_CTL_ADD, m_timerFd, &timerEvent) != 0)
258 : {
259 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
260 0 : return false;
261 : }
262 :
263 : // and lastly the eventfd to wake the poll loop
264 29 : m_eventFd = eventfd(0, EFD_CLOEXEC);
265 29 : if (m_eventFd < 0)
266 : {
267 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "eventfd create failed");
268 0 : return false;
269 : }
270 :
271 : // add the timer event to epoll
272 29 : epoll_event wakeEvent = {.events = EPOLLIN, .data = {.fd = m_eventFd}};
273 29 : if (epoll_ctl(m_epollFd, EPOLL_CTL_ADD, m_eventFd, &wakeEvent) != 0)
274 : {
275 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to add eventfd");
276 0 : return false;
277 : }
278 :
279 29 : return true;
280 : }
281 :
282 29 : void ChannelImpl::termChannel()
283 : {
284 : // close the socket and the epoll and timer fds
285 29 : if (m_sock >= 0)
286 0 : disconnectNoLock();
287 29 : if ((m_epollFd >= 0) && (close(m_epollFd) != 0))
288 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "closing epoll fd failed");
289 29 : if ((m_timerFd >= 0) && (close(m_timerFd) != 0))
290 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "closing timer fd failed");
291 29 : if ((m_eventFd >= 0) && (close(m_eventFd) != 0))
292 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "closing event fd failed");
293 :
294 : // if any method calls are still outstanding then complete them with errors now
295 29 : for (auto &entry : m_methodCalls)
296 : {
297 0 : completeWithError(&entry.second, "Channel destructed");
298 : }
299 :
300 29 : m_methodCalls.clear();
301 : }
302 :
303 29 : void ChannelImpl::disconnect()
304 : {
305 : // disconnect from the socket
306 : {
307 29 : std::lock_guard<std::mutex> locker(m_lock);
308 :
309 29 : if (m_sock < 0)
310 1 : return;
311 :
312 28 : disconnectNoLock();
313 16 : }
314 :
315 : // wake the wait(...) call so client code is blocked there it is woken
316 28 : if (m_eventFd >= 0)
317 : {
318 28 : uint64_t wakeup = 1;
319 28 : if (TEMP_FAILURE_RETRY(write(m_eventFd, &wakeup, sizeof(wakeup))) != sizeof(wakeup))
320 : {
321 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to write to wake event fd");
322 : }
323 : }
324 : }
325 :
326 : // -----------------------------------------------------------------------------
327 : /*!
328 : \internal
329 :
330 :
331 : */
332 29 : void ChannelImpl::disconnectNoLock()
333 : {
334 29 : if (m_sock < 0)
335 : {
336 0 : RIALTO_IPC_LOG_WARN("not connected\n");
337 0 : return;
338 : }
339 :
340 : // remove the socket from epoll
341 29 : if (epoll_ctl(m_epollFd, EPOLL_CTL_DEL, m_sock, nullptr) != 0)
342 : {
343 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_ctl failed to remove socket");
344 : }
345 :
346 : // shutdown and close the socket
347 29 : if (shutdown(m_sock, SHUT_RDWR) != 0)
348 : {
349 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "shutdown error");
350 : }
351 29 : if (close(m_sock) != 0)
352 : {
353 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "close error");
354 : }
355 :
356 29 : m_sock = -1;
357 : }
358 :
359 12727 : bool ChannelImpl::isConnected() const
360 : {
361 12727 : return isConnectedInternal();
362 : }
363 :
364 12756 : bool ChannelImpl::isConnectedInternal() const
365 : {
366 12756 : std::lock_guard<std::mutex> locker(m_lock);
367 12756 : return (m_sock >= 0);
368 : }
369 :
370 0 : int ChannelImpl::fd() const
371 : {
372 0 : return m_epollFd;
373 : }
374 :
375 3165 : bool ChannelImpl::wait(int timeoutMSecs)
376 : {
377 3165 : if ((m_epollFd < 0) || !isConnected())
378 : {
379 0 : return false;
380 : }
381 :
382 : // wait for any event (with timeout)
383 : struct pollfd fds[2];
384 3165 : fds[0].fd = m_epollFd;
385 3165 : fds[0].events = POLLIN;
386 :
387 3165 : int rc = TEMP_FAILURE_RETRY(poll(fds, 1, timeoutMSecs));
388 3165 : if (rc < 0)
389 : {
390 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "poll failed?");
391 0 : return false;
392 : }
393 :
394 3165 : return isConnected();
395 : }
396 :
397 3195 : bool ChannelImpl::process()
398 : {
399 3195 : if (!isConnected())
400 17 : return false;
401 :
402 : struct epoll_event events[3];
403 3178 : int rc = TEMP_FAILURE_RETRY(epoll_wait(m_epollFd, events, 3, 0));
404 3178 : if (rc < 0)
405 : {
406 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "epoll_wait failed");
407 0 : return false;
408 : }
409 :
410 : enum
411 : {
412 : HaveSocketEvent = 0x1,
413 : HaveTimeoutEvent = 0x2,
414 : HaveWakeEvent = 0x4
415 : };
416 3178 : unsigned eventsMask = 0;
417 3203 : for (int i = 0; i < rc; i++)
418 : {
419 25 : if (events[i].data.fd == m_sock)
420 24 : eventsMask |= HaveSocketEvent;
421 1 : else if (events[i].data.fd == m_timerFd)
422 1 : eventsMask |= HaveTimeoutEvent;
423 0 : else if (events[i].data.fd == m_eventFd)
424 0 : eventsMask |= HaveWakeEvent;
425 : }
426 :
427 3178 : if ((eventsMask & HaveSocketEvent) && !processSocketEvent())
428 : {
429 1 : std::map<uint64_t, MethodCall> callsToDelete;
430 : {
431 1 : std::lock_guard<std::mutex> locker(m_lock);
432 1 : callsToDelete = m_methodCalls;
433 1 : m_methodCalls.clear();
434 : }
435 :
436 1 : for (auto &entry : callsToDelete)
437 : {
438 0 : completeWithError(&entry.second, "Socket down");
439 : }
440 1 : return false;
441 : }
442 :
443 3177 : if (eventsMask & HaveTimeoutEvent)
444 1 : processTimeoutEvent();
445 :
446 3177 : if (eventsMask & HaveWakeEvent)
447 0 : processWakeEvent();
448 :
449 3177 : return isConnected();
450 : }
451 :
452 24 : bool ChannelImpl::unsubscribe(int eventTag)
453 : {
454 24 : std::lock_guard<std::mutex> locker(m_eventsLock);
455 24 : bool success = false;
456 :
457 24 : auto it = std::find_if(m_eventHandlers.begin(), m_eventHandlers.end(),
458 36 : [&](const auto &item) { return item.second.id == eventTag; });
459 24 : if (m_eventHandlers.end() != it)
460 : {
461 24 : m_eventHandlers.erase(it);
462 24 : success = true;
463 : }
464 :
465 24 : return success;
466 : }
467 :
468 : // -----------------------------------------------------------------------------
469 : /*!
470 : \internal
471 :
472 :
473 : */
474 24 : bool ChannelImpl::processSocketEvent()
475 : {
476 : static std::mutex bufLock;
477 24 : std::lock_guard<std::mutex> bufLocker(bufLock);
478 :
479 28 : static std::vector<uint8_t> dataBuf(kMaxMessageSize);
480 28 : static std::vector<uint8_t> ctrlBuf(CMSG_SPACE(SCM_MAX_FD * sizeof(int)));
481 :
482 : // read all messages from the client socket, we break out if the socket is closed
483 : // or EWOULDBLOCK is returned on a read (ie. no more messages to read)
484 : while (true)
485 : {
486 47 : struct msghdr msg = {nullptr};
487 47 : struct iovec io = {.iov_base = dataBuf.data(), .iov_len = dataBuf.size()};
488 :
489 47 : bzero(&msg, sizeof(msg));
490 47 : msg.msg_iov = &io;
491 47 : msg.msg_iovlen = 1;
492 47 : msg.msg_control = ctrlBuf.data();
493 47 : msg.msg_controllen = ctrlBuf.size();
494 :
495 : // read one message
496 47 : ssize_t rd = TEMP_FAILURE_RETRY(recvmsg(m_sock, &msg, MSG_CMSG_CLOEXEC));
497 47 : if (rd < 0)
498 : {
499 23 : if (errno != EWOULDBLOCK)
500 : {
501 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "error reading client socket");
502 :
503 0 : std::lock_guard<std::mutex> locker(m_lock);
504 0 : disconnectNoLock();
505 0 : return false;
506 : }
507 :
508 23 : break;
509 : }
510 24 : else if (rd == 0)
511 : {
512 : // server closed connection, and we've read all data
513 1 : RIALTO_IPC_LOG_INFO("socket remote end closed, disconnecting channel");
514 :
515 1 : std::lock_guard<std::mutex> locker(m_lock);
516 1 : disconnectNoLock();
517 1 : return false;
518 : }
519 23 : else if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))
520 : {
521 0 : RIALTO_IPC_LOG_WARN("received truncated message from server, discarding");
522 :
523 : // make sure to close all the fds, otherwise we'll leak them, this
524 : // will read the fds and return in a vector, which will then be
525 : // destroyed, closing all the fds
526 0 : readMessageFds(&msg, 16);
527 : }
528 : else
529 : {
530 : // if there is control data then assume fd(s) have been passed
531 23 : std::vector<FileDescriptor> fds;
532 23 : if (msg.msg_controllen > 0)
533 : {
534 0 : fds = readMessageFds(&msg, 32);
535 : }
536 :
537 : // process the message from the server
538 23 : processServerMessage(dataBuf.data(), rd, &fds);
539 : }
540 : }
541 :
542 23 : return true;
543 13 : }
544 :
545 : // -----------------------------------------------------------------------------
546 : /*!
547 : \internal
548 :
549 : Called from process() to check if the timerfd has expired and if so cancel
550 : the any outstanding method calls that have now timed-out.
551 :
552 : */
553 1 : void ChannelImpl::processTimeoutEvent()
554 : {
555 : // read the timerfd to clear any expirations
556 : uint64_t expirations;
557 1 : ssize_t rd = TEMP_FAILURE_RETRY(read(m_timerFd, &expirations, sizeof(expirations)));
558 1 : if (rd < 0)
559 : {
560 0 : if (errno != EWOULDBLOCK)
561 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "error reading timerfd");
562 0 : return;
563 : }
564 :
565 : // check if any method call has no expired
566 1 : std::unique_lock<std::mutex> locker(m_lock);
567 :
568 : // stores the timed-out method calls
569 1 : std::vector<MethodCall> timedOuts;
570 :
571 : // remove the method calls that have expired
572 1 : const auto kNow = std::chrono::steady_clock::now();
573 1 : auto it = m_methodCalls.begin();
574 2 : while (it != m_methodCalls.end())
575 : {
576 1 : if (kNow >= it->second.timeoutDeadline)
577 : {
578 1 : timedOuts.emplace_back(it->second);
579 1 : it = m_methodCalls.erase(it);
580 : }
581 : else
582 : {
583 0 : ++it;
584 : }
585 : }
586 :
587 : // if we still have method calls available, then re-calculate the timer for the next timeout
588 1 : if (!m_methodCalls.empty())
589 : {
590 0 : updateTimeoutTimer();
591 : }
592 :
593 : // drop the lock and now terminate the timed out method calls
594 1 : locker.unlock();
595 :
596 2 : for (auto &call : timedOuts)
597 : {
598 2 : completeWithError(&call, "Timed out");
599 : }
600 1 : }
601 :
602 : // -----------------------------------------------------------------------------
603 : /*!
604 : \internal
605 :
606 : Called from process() when the eventfd was used to wake the event loop. All
607 : the function does is read the eventfd to clear it's value.
608 :
609 : */
610 0 : void ChannelImpl::processWakeEvent()
611 : {
612 : uint64_t ignore;
613 0 : if (TEMP_FAILURE_RETRY(read(m_eventFd, &ignore, sizeof(ignore))) != sizeof(ignore))
614 : {
615 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "Failed to read wake eventfd to clear it");
616 : }
617 : }
618 :
619 : // -----------------------------------------------------------------------------
620 : /*!
621 : \internal
622 :
623 : Updates the timerfd to the time of the next method call timeout. If method
624 : calls are pending then the timer is disabled.
625 :
626 : This should be called whenever a new method is called or a method has
627 : completed.
628 :
629 : \note Must be called while holding the m_lock mutex.
630 :
631 : */
632 39 : void ChannelImpl::updateTimeoutTimer()
633 : {
634 39 : struct itimerspec ts = {{0}};
635 :
636 : // if no method calls then just disarm the timer
637 39 : if (!m_methodCalls.empty())
638 : {
639 : // otherwise, find the next soonest timeout
640 20 : std::chrono::steady_clock::time_point nextTimeout = std::chrono::steady_clock::time_point::max();
641 : auto nextTimeoutCall =
642 20 : std::min_element(m_methodCalls.begin(), m_methodCalls.end(), [](const auto &elem, const auto ¤tMin)
643 0 : { return elem.second.timeoutDeadline < currentMin.second.timeoutDeadline; });
644 20 : if (nextTimeoutCall != m_methodCalls.end())
645 : {
646 20 : nextTimeout = nextTimeoutCall->second.timeoutDeadline;
647 : }
648 :
649 : // set the timerfd to the next duration
650 : const std::chrono::microseconds kDuration =
651 20 : std::chrono::duration_cast<std::chrono::microseconds>(nextTimeout - std::chrono::steady_clock::now());
652 20 : if (kDuration <= std::chrono::microseconds::zero())
653 : {
654 0 : ts.it_value.tv_nsec = 1000;
655 : }
656 : else
657 : {
658 20 : ts.it_value.tv_sec = static_cast<time_t>(std::chrono::duration_cast<std::chrono::seconds>(kDuration).count());
659 20 : ts.it_value.tv_nsec = static_cast<int32_t>((kDuration.count() % 1000000) * 1000);
660 : }
661 :
662 20 : RIALTO_IPC_LOG_DEBUG("next timeout in %" PRId64 "us - %ld.%09lds", kDuration.count(), ts.it_value.tv_sec,
663 : ts.it_value.tv_nsec);
664 : }
665 :
666 : // write the timeout value
667 39 : if (timerfd_settime(m_timerFd, 0, &ts, nullptr) != 0)
668 : {
669 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to write to timerfd");
670 : }
671 39 : }
672 :
673 : // -----------------------------------------------------------------------------
674 : /*!
675 : \internal
676 :
677 : Processes a single message from the server, it may be a method call response
678 : or an event.
679 :
680 :
681 : */
682 23 : void ChannelImpl::processServerMessage(const uint8_t *data, size_t dataLen, std::vector<FileDescriptor> *fds)
683 : {
684 : // parse the message
685 23 : transport::MessageFromServer message;
686 23 : if (!message.ParseFromArray(data, static_cast<int>(dataLen)))
687 : {
688 0 : RIALTO_IPC_LOG_ERROR("invalid message from server");
689 0 : return;
690 : }
691 :
692 : // check if an event or a reply to a request
693 23 : if (message.has_reply())
694 : {
695 13 : processReplyFromServer(message.reply(), fds);
696 : }
697 10 : else if (message.has_error())
698 : {
699 6 : processErrorFromServer(message.error());
700 : }
701 4 : else if (message.has_event())
702 : {
703 4 : processEventFromServer(message.event(), fds);
704 : }
705 : else
706 : {
707 0 : RIALTO_IPC_LOG_ERROR("message from server is missing reply or event type");
708 : }
709 23 : }
710 :
711 : // -----------------------------------------------------------------------------
712 : /*!
713 : \internal
714 :
715 :
716 : */
717 13 : void ChannelImpl::processReplyFromServer(const transport::MethodCallReply &reply, std::vector<FileDescriptor> *fds)
718 : {
719 13 : RIALTO_IPC_LOG_DEBUG("processing reply from server");
720 :
721 13 : std::unique_lock<std::mutex> locker(m_lock);
722 :
723 : // find the original request
724 13 : const uint64_t kSerialId = reply.reply_id();
725 13 : auto it = m_methodCalls.find(kSerialId);
726 13 : if (it == m_methodCalls.end())
727 : {
728 0 : RIALTO_IPC_LOG_ERROR("failed to find request for received reply with id %" PRIu64 "", reply.reply_id());
729 0 : return;
730 : }
731 :
732 : // take the method call and remove from the map of outstanding calls
733 13 : MethodCall methodCall = it->second;
734 13 : m_methodCalls.erase(it);
735 :
736 : // update the timeout timer now a method call has been processed
737 13 : updateTimeoutTimer();
738 :
739 : // can now drop the lock
740 13 : locker.unlock();
741 :
742 : // this is an actual reply so try and read it
743 13 : if (!methodCall.response->ParseFromString(reply.reply_message()))
744 : {
745 0 : RIALTO_IPC_LOG_ERROR("failed to parse method reply from server");
746 0 : completeWithError(&methodCall, "Failed to parse reply message");
747 : }
748 13 : else if (!addReplyFileDescriptors(methodCall.response, fds))
749 : {
750 0 : RIALTO_IPC_LOG_ERROR("mismatch of file descriptors to the reply");
751 0 : completeWithError(&methodCall, "Mismatched file descriptors in message");
752 : }
753 13 : else if (methodCall.closure)
754 : {
755 13 : RIALTO_IPC_LOG_DEBUG("reply{ serial %" PRIu64 " } - %s { %s }", kSerialId,
756 : methodCall.response->GetTypeName().c_str(), methodCall.response->ShortDebugString().c_str());
757 :
758 13 : complete(&methodCall);
759 : }
760 : }
761 :
762 : // -----------------------------------------------------------------------------
763 : /*!
764 : \internal
765 :
766 :
767 : */
768 6 : void ChannelImpl::processErrorFromServer(const transport::MethodCallError &error)
769 : {
770 6 : RIALTO_IPC_LOG_DEBUG("processing error from server");
771 :
772 6 : std::unique_lock<std::mutex> locker(m_lock);
773 :
774 : // find the original request
775 6 : const uint64_t kSerialId = error.reply_id();
776 6 : auto it = m_methodCalls.find(kSerialId);
777 6 : if (it == m_methodCalls.end())
778 : {
779 0 : RIALTO_IPC_LOG_ERROR("failed to find request for received reply with id %" PRIu64 "", error.reply_id());
780 0 : return;
781 : }
782 :
783 : // take the method call and remove from the map of outstanding calls
784 6 : MethodCall methodCall = it->second;
785 6 : m_methodCalls.erase(it);
786 :
787 : // update the timeout timer now a method call has been processed
788 6 : updateTimeoutTimer();
789 :
790 : // can now drop the lock
791 6 : locker.unlock();
792 :
793 6 : RIALTO_IPC_LOG_DEBUG("error{ serial %" PRIu64 " } - %s", kSerialId, error.error_reason().c_str());
794 :
795 : // complete the call with an error
796 6 : completeWithError(&methodCall, error.error_reason());
797 : }
798 :
799 : // -----------------------------------------------------------------------------
800 : /*!
801 : \internal
802 :
803 :
804 : */
805 4 : void ChannelImpl::processEventFromServer(const transport::EventFromServer &event, std::vector<FileDescriptor> *fds)
806 : {
807 4 : RIALTO_IPC_LOG_DEBUG("processing event from server");
808 :
809 4 : const std::string &kEventName = event.event_name();
810 :
811 4 : std::lock_guard<std::mutex> locker(m_eventsLock);
812 :
813 4 : auto range = m_eventHandlers.equal_range(kEventName);
814 4 : if (range.first == range.second)
815 : {
816 0 : RIALTO_IPC_LOG_WARN("no handler for event %s", kEventName.c_str());
817 0 : return;
818 : }
819 :
820 4 : const google::protobuf::Descriptor *kDescriptor = range.first->second.descriptor;
821 :
822 : const google::protobuf::Message *kPrototype =
823 4 : google::protobuf::MessageFactory::generated_factory()->GetPrototype(kDescriptor);
824 4 : if (!kPrototype)
825 : {
826 0 : RIALTO_IPC_LOG_ERROR("failed to create prototype for event %s", kEventName.c_str());
827 0 : return;
828 : }
829 :
830 4 : std::shared_ptr<google::protobuf::Message> message(kPrototype->New());
831 4 : if (!message)
832 : {
833 0 : RIALTO_IPC_LOG_ERROR("failed to create mutable message from prototype");
834 0 : return;
835 : }
836 :
837 4 : if (!message->ParseFromString(event.message()))
838 : {
839 0 : RIALTO_IPC_LOG_ERROR("failed to parse message for event %s", kEventName.c_str());
840 : }
841 4 : else if (!addReplyFileDescriptors(message.get(), fds))
842 : {
843 0 : RIALTO_IPC_LOG_ERROR("mismatch of file descriptors to the reply");
844 : }
845 : else
846 : {
847 4 : RIALTO_IPC_LOG_DEBUG("event{ %s } - %s { %s }", kEventName.c_str(), message->GetTypeName().c_str(),
848 : message->ShortDebugString().c_str());
849 :
850 8 : for (auto it = range.first; it != range.second; ++it)
851 : {
852 4 : it->second.handler(message);
853 : }
854 : }
855 : }
856 :
857 : // -----------------------------------------------------------------------------
858 : /*!
859 : \internal
860 : \static
861 :
862 : Reads all the file descriptors from a unix domain socket received \a msg.
863 : It returns the file descriptors as a vector of FileDescriptor objects,
864 : these objects safely store the fd and close them when they're destructed.
865 :
866 : The \a limit specifies the maximum number of fds to store, if more were
867 : sent then they are automatically closed and not returned in the vector.
868 :
869 : */
870 0 : std::vector<FileDescriptor> ChannelImpl::readMessageFds(const struct msghdr *msg, size_t limit)
871 : {
872 0 : std::vector<FileDescriptor> fds;
873 :
874 0 : for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg); cmsg != nullptr;
875 0 : cmsg = CMSG_NXTHDR(const_cast<struct msghdr *>(msg), cmsg))
876 : {
877 0 : if ((cmsg->cmsg_level == SOL_SOCKET) && (cmsg->cmsg_type == SCM_RIGHTS))
878 : {
879 0 : const unsigned kFdsLength = cmsg->cmsg_len - CMSG_LEN(0);
880 0 : if ((kFdsLength < sizeof(int)) || ((kFdsLength % sizeof(int)) != 0))
881 : {
882 0 : RIALTO_IPC_LOG_ERROR("invalid fd array size");
883 : }
884 : else
885 : {
886 0 : const size_t n = kFdsLength / sizeof(int);
887 0 : RIALTO_IPC_LOG_DEBUG("received %zu fds", n);
888 :
889 0 : fds.reserve(std::min(limit, n));
890 :
891 0 : const int *kFds = reinterpret_cast<int *>(CMSG_DATA(cmsg));
892 0 : for (size_t i = 0; i < n; i++)
893 : {
894 0 : RIALTO_IPC_LOG_DEBUG("received fd %d", kFds[i]);
895 :
896 0 : if (fds.size() >= limit)
897 : {
898 0 : RIALTO_IPC_LOG_ERROR(
899 : "received to many file descriptors, exceeding max per message, closing left overs");
900 : }
901 : else
902 : {
903 0 : firebolt::rialto::ipc::FileDescriptor fileDescriptor(kFds[i]);
904 0 : if (!fileDescriptor.isValid())
905 : {
906 0 : RIALTO_IPC_LOG_ERROR("received invalid fd (couldn't dup)");
907 : }
908 : else
909 : {
910 0 : fds.emplace_back(std::move(fileDescriptor));
911 : }
912 : }
913 :
914 0 : if (close(kFds[i]) != 0)
915 0 : RIALTO_IPC_LOG_SYS_ERROR(errno, "failed to close received fd");
916 : }
917 : }
918 : }
919 : }
920 :
921 0 : return fds;
922 : }
923 :
924 : // -----------------------------------------------------------------------------
925 : /*!
926 : \static
927 : \internal
928 :
929 : Places the received file descriptors into the protobuf message.
930 :
931 : It works by iterating over the fields in the message, finding ones that are
932 : marked as 'field_is_fd' and then replacing the received integer value with
933 : an actual file descriptor.
934 :
935 : */
936 17 : bool ChannelImpl::addReplyFileDescriptors(google::protobuf::Message *reply,
937 : std::vector<firebolt::rialto::ipc::FileDescriptor> *fds)
938 : {
939 17 : auto fdIterator = fds->begin();
940 :
941 17 : const google::protobuf::Descriptor *kDescriptor = reply->GetDescriptor();
942 17 : const google::protobuf::Reflection *kReflection = nullptr;
943 :
944 17 : const int n = kDescriptor->field_count();
945 34 : for (int i = 0; i < n; i++)
946 : {
947 17 : auto fieldDescriptor = kDescriptor->field(i);
948 17 : if (fieldDescriptor->options().HasExtension(::firebolt::rialto::ipc::field_is_fd) &&
949 0 : fieldDescriptor->options().GetExtension(::firebolt::rialto::ipc::field_is_fd))
950 : {
951 0 : if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
952 : {
953 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but not an int32 type");
954 0 : return false;
955 : }
956 :
957 0 : if (!kReflection)
958 : {
959 0 : kReflection = reply->GetReflection();
960 : }
961 :
962 0 : if (kReflection->HasField(*reply, fieldDescriptor))
963 : {
964 0 : if (fdIterator == fds->end())
965 : {
966 0 : RIALTO_IPC_LOG_ERROR("field is marked as containing an fd but none or too few were supplied");
967 0 : return false;
968 : }
969 :
970 0 : kReflection->SetInt32(reply, fieldDescriptor, fdIterator->fd());
971 0 : ++fdIterator;
972 : }
973 : }
974 : }
975 :
976 17 : if (fdIterator != fds->end())
977 : {
978 0 : RIALTO_IPC_LOG_ERROR("received too many file descriptors in the message");
979 0 : return false;
980 : }
981 :
982 : // we now need to release all the fds stored in the vector, otherwise they
983 : // will be closed when the vector is destroyed. From now onwards it is the
984 : // caller's responsibility to close the fds in the returned protobuf message
985 : // object
986 17 : for (firebolt::rialto::ipc::FileDescriptor &fd : *fds)
987 : {
988 0 : fd.release();
989 : }
990 :
991 17 : return true;
992 : }
993 :
994 : // -----------------------------------------------------------------------------
995 : /*!
996 : \internal
997 : \static
998 :
999 :
1000 : */
1001 13 : void ChannelImpl::complete(MethodCall *call)
1002 : {
1003 13 : if (call->closure)
1004 : {
1005 13 : call->closure->Run();
1006 13 : call->closure = nullptr;
1007 : }
1008 : }
1009 :
1010 : // -----------------------------------------------------------------------------
1011 : /*!
1012 : \internal
1013 : \static
1014 :
1015 :
1016 : */
1017 7 : void ChannelImpl::completeWithError(MethodCall *call, std::string reason)
1018 : {
1019 7 : RIALTO_IPC_LOG_DEBUG("completing method call with error '%s'", reason.c_str());
1020 :
1021 7 : if (call->controller)
1022 : {
1023 7 : call->controller->setMethodCallFailed(std::move(reason));
1024 7 : call->controller = nullptr;
1025 : }
1026 :
1027 7 : if (call->closure)
1028 : {
1029 7 : call->closure->Run();
1030 7 : call->closure = nullptr;
1031 : }
1032 : }
1033 :
1034 : // -----------------------------------------------------------------------------
1035 : /*!
1036 : \static
1037 :
1038 : Iterates through the message and finds any file descriptor type fields, if
1039 : found it adds the fds to the returned vector.
1040 :
1041 : */
1042 21 : std::vector<int> ChannelImpl::getMessageFds(const google::protobuf::Message &message)
1043 : {
1044 21 : std::vector<int> fds;
1045 :
1046 21 : auto descriptor = message.GetDescriptor();
1047 21 : const int n = descriptor->field_count();
1048 75 : for (int i = 0; i < n; i++)
1049 : {
1050 54 : auto fieldDescriptor = descriptor->field(i);
1051 54 : if (fieldDescriptor->options().HasExtension(::firebolt::rialto::ipc::field_is_fd) &&
1052 0 : fieldDescriptor->options().GetExtension(::firebolt::rialto::ipc::field_is_fd))
1053 : {
1054 0 : if (fieldDescriptor->type() != google::protobuf::FieldDescriptor::TYPE_INT32)
1055 : {
1056 0 : RIALTO_IPC_LOG_ERROR("field '%s' is marked as containing an fd but not an int32 type",
1057 : fieldDescriptor->full_name().c_str());
1058 : }
1059 : else
1060 : {
1061 0 : auto reflection = message.GetReflection();
1062 0 : int fileDescriptor = reflection->GetInt32(message, fieldDescriptor);
1063 0 : fds.emplace_back(fileDescriptor);
1064 : }
1065 : }
1066 : }
1067 :
1068 21 : return fds;
1069 : }
1070 :
1071 : // -----------------------------------------------------------------------------
1072 : /*!
1073 : \overload
1074 :
1075 : \quote
1076 : Call the given method of the remote service. The signature of this
1077 : procedure looks the same as Service::CallMethod(), but the requirements
1078 : are less strict in one important way: the request and response objects
1079 : need not be of any specific class as long as their descriptors are
1080 : method->input_type() and method->output_type().
1081 :
1082 : */
1083 21 : void ChannelImpl::CallMethod(const google::protobuf::MethodDescriptor *method, // NOLINT(build/function_format)
1084 : google::protobuf::RpcController *controller, const google::protobuf::Message *request,
1085 : google::protobuf::Message *response, google::protobuf::Closure *done)
1086 : {
1087 21 : MethodCall methodCall{std::chrono::steady_clock::now() + m_timeout,
1088 21 : dynamic_cast<ClientControllerImpl *>(controller), response, done};
1089 :
1090 : //
1091 21 : const uint64_t kSerialId = m_serialCounter++;
1092 :
1093 : // create the transport request
1094 21 : transport::MessageToServer message;
1095 21 : transport::MethodCall *call = message.mutable_call();
1096 21 : call->set_serial_id(kSerialId);
1097 21 : call->set_service_name(method->service()->full_name());
1098 21 : call->set_method_name(method->name());
1099 :
1100 : // copy in the actual message data
1101 21 : std::string reqString = request->SerializeAsString();
1102 21 : call->set_request_message(std::move(reqString));
1103 :
1104 21 : const size_t kRequiredDataLen = message.ByteSizeLong();
1105 21 : if (kRequiredDataLen > kMaxMessageSize)
1106 : {
1107 0 : RIALTO_IPC_LOG_ERROR("method call to big to send (%zu, max %zu", kRequiredDataLen, kMaxMessageSize);
1108 0 : completeWithError(&methodCall, "Method call to big");
1109 0 : return;
1110 : }
1111 :
1112 : // extract the fds from the message
1113 21 : const std::vector<int> kFds = getMessageFds(*request);
1114 21 : const size_t kRequiredCtrlLen = kFds.empty() ? 0 : CMSG_SPACE(sizeof(int) * kFds.size());
1115 :
1116 : // build the socket message to send
1117 : auto msgBuf =
1118 21 : m_sendBufPool.allocateShared<uint8_t>(sizeof(msghdr) + sizeof(iovec) + kRequiredCtrlLen + kRequiredDataLen);
1119 :
1120 21 : auto *header = reinterpret_cast<msghdr *>(msgBuf.get());
1121 21 : bzero(header, sizeof(msghdr));
1122 :
1123 21 : auto *ctrl = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr));
1124 21 : header->msg_control = ctrl;
1125 21 : header->msg_controllen = kRequiredCtrlLen;
1126 :
1127 21 : auto *iov = reinterpret_cast<iovec *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen);
1128 21 : header->msg_iov = iov;
1129 21 : header->msg_iovlen = 1;
1130 :
1131 21 : auto *data = reinterpret_cast<uint8_t *>(msgBuf.get() + sizeof(msghdr) + kRequiredCtrlLen + sizeof(iovec));
1132 21 : iov->iov_base = data;
1133 21 : iov->iov_len = kRequiredDataLen;
1134 :
1135 : // copy in the data
1136 21 : message.SerializeWithCachedSizesToArray(data);
1137 :
1138 : // next check if the request is sending any fd's
1139 21 : if (!kFds.empty())
1140 : {
1141 0 : struct cmsghdr *cmsg = CMSG_FIRSTHDR(header);
1142 0 : if (!cmsg)
1143 : {
1144 0 : RIALTO_IPC_LOG_ERROR("odd, failed to get the first cmsg header");
1145 0 : completeWithError(&methodCall, "Internal error");
1146 0 : return;
1147 : }
1148 :
1149 0 : cmsg->cmsg_level = SOL_SOCKET;
1150 0 : cmsg->cmsg_type = SCM_RIGHTS;
1151 0 : cmsg->cmsg_len = CMSG_LEN(sizeof(int) * kFds.size());
1152 0 : memcpy(CMSG_DATA(cmsg), kFds.data(), sizeof(int) * kFds.size());
1153 0 : header->msg_controllen = cmsg->cmsg_len;
1154 : }
1155 :
1156 : // check if the method is expecting a reply
1157 22 : const bool kNoReplyExpected = method->options().HasExtension(::firebolt::rialto::ipc::no_reply) &&
1158 1 : method->options().GetExtension(::firebolt::rialto::ipc::no_reply);
1159 :
1160 : // finally, send the message
1161 21 : std::unique_lock<std::mutex> locker(m_lock);
1162 :
1163 21 : if (m_sock < 0)
1164 : {
1165 0 : locker.unlock();
1166 0 : completeWithError(&methodCall, "Not connected");
1167 : }
1168 21 : else if (sendmsg(m_sock, header, MSG_NOSIGNAL) != static_cast<ssize_t>(kRequiredDataLen))
1169 : {
1170 0 : locker.unlock();
1171 0 : completeWithError(&methodCall, "Failed to send message");
1172 : }
1173 : else
1174 : {
1175 21 : RIALTO_IPC_LOG_DEBUG("call{ serial %" PRIu64 " } - %s.%s { %s }", kSerialId, call->service_name().c_str(),
1176 : call->method_name().c_str(), request->ShortDebugString().c_str());
1177 :
1178 21 : if (kNoReplyExpected)
1179 : {
1180 : // no reply from server is expected, however if the caller supplied
1181 : // a closure (it shouldn't) we should still call it now to indicate
1182 : // the method call has been made
1183 1 : if (done)
1184 1 : done->Run();
1185 : }
1186 : else
1187 : {
1188 : // add the message to the queue so we pick-up the reply
1189 20 : m_methodCalls.emplace(kSerialId, methodCall);
1190 :
1191 : // update the single timeout timer
1192 20 : updateTimeoutTimer();
1193 : }
1194 : }
1195 11 : }
1196 :
1197 58 : int ChannelImpl::subscribeImpl(const std::string &kEventName, const google::protobuf::Descriptor *descriptor,
1198 : EventHandler &&handler)
1199 : {
1200 58 : std::lock_guard<std::mutex> locker(m_eventsLock);
1201 :
1202 58 : const int kTag = m_eventTagCounter++;
1203 58 : m_eventHandlers.emplace(kEventName, Event{kTag, descriptor, std::move(handler)});
1204 :
1205 58 : return kTag;
1206 : }
1207 : }; // namespace firebolt::rialto::ipc
|