Line data Source code
1 : /*
2 : * If not stated otherwise in this file or this component's LICENSE file the
3 : * following copyright and licenses apply:
4 : *
5 : * Copyright 2022 Sky UK
6 : *
7 : * Licensed under the Apache License, Version 2.0 (the "License");
8 : * you may not use this file except in compliance with the License.
9 : * You may obtain a copy of the License at
10 : *
11 : * http://www.apache.org/licenses/LICENSE-2.0
12 : *
13 : * Unless required by applicable law or agreed to in writing, software
14 : * distributed under the License is distributed on an "AS IS" BASIS,
15 : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 : * See the License for the specific language governing permissions and
17 : * limitations under the License.
18 : */
19 :
20 : #ifndef FIREBOLT_RIALTO_IPC_IPC_CHANNEL_IMPL_H_
21 : #define FIREBOLT_RIALTO_IPC_IPC_CHANNEL_IMPL_H_
22 :
23 : #include "FileDescriptor.h"
24 : #include "IIpcChannel.h"
25 : #include "IpcClientControllerImpl.h"
26 : #include "SimpleBufferPool.h"
27 :
28 : #include "rialtoipc-transport.pb.h"
29 : #include <google/protobuf/service.h>
30 :
31 : #include <atomic>
32 : #include <chrono>
33 : #include <functional>
34 : #include <map>
35 : #include <memory>
36 : #include <mutex>
37 : #include <string>
38 : #include <vector>
39 :
40 : #include <sys/socket.h>
41 :
42 : namespace firebolt::rialto::ipc
43 : {
44 : class ChannelFactory : public IChannelFactory
45 : {
46 : public:
47 33 : ChannelFactory() = default;
48 33 : ~ChannelFactory() override = default;
49 :
50 : std::shared_ptr<IChannel> createChannel(int sockFd) override;
51 : std::shared_ptr<IChannel> createChannel(const std::string &socketPath) override;
52 : };
53 :
54 : class ChannelImpl final : public IChannel
55 : {
56 : public:
57 : explicit ChannelImpl(int sockFd);
58 : explicit ChannelImpl(const std::string &socketPath);
59 : ~ChannelImpl() final;
60 :
61 : void disconnect() override;
62 : bool isConnected() const override;
63 :
64 : int fd() const override;
65 : bool wait(int timeoutMSecs) override;
66 : bool process() override;
67 : bool unsubscribe(int eventTag) override;
68 :
69 : void CallMethod(const google::protobuf::MethodDescriptor *method, google::protobuf::RpcController *controller,
70 : const google::protobuf::Message *request, google::protobuf::Message *response,
71 : google::protobuf::Closure *done) override;
72 :
73 : using EventHandler = std::function<void(const std::shared_ptr<google::protobuf::Message> &msg)>;
74 :
75 : int subscribeImpl(const std::string &name, const google::protobuf::Descriptor *descriptor,
76 : EventHandler &&handler) override;
77 :
78 : private:
79 : void disconnectNoLock();
80 :
81 : bool processSocketEvent();
82 : void processTimeoutEvent();
83 : void processWakeEvent();
84 :
85 : void processServerMessage(const uint8_t *data, size_t len, std::vector<FileDescriptor> *fds);
86 : void processReplyFromServer(const ::firebolt::rialto::ipc::transport::MethodCallReply &reply,
87 : std::vector<FileDescriptor> *fds);
88 : void processErrorFromServer(const ::firebolt::rialto::ipc::transport::MethodCallError &error);
89 : void processEventFromServer(const ::firebolt::rialto::ipc::transport::EventFromServer &event,
90 : std::vector<FileDescriptor> *fds);
91 :
92 : bool createConnectedSocket(const std::string &socketPath);
93 : bool attachSocket(int sockFd);
94 : bool initChannel();
95 : void termChannel();
96 : bool isConnectedInternal() const; // to avoid calling virtual method in constructor
97 :
98 : static std::vector<FileDescriptor> readMessageFds(const struct msghdr *msg, size_t limit);
99 : static std::vector<int> getMessageFds(const google::protobuf::Message &message);
100 :
101 : static bool addReplyFileDescriptors(google::protobuf::Message *reply, std::vector<FileDescriptor> *fds);
102 :
103 : struct MethodCall
104 : {
105 : std::chrono::steady_clock::time_point timeoutDeadline;
106 : ClientControllerImpl *controller = nullptr;
107 : google::protobuf::Message *response = nullptr;
108 : google::protobuf::Closure *closure = nullptr;
109 : };
110 :
111 : void updateTimeoutTimer();
112 :
113 : static void complete(MethodCall *call);
114 : static void completeWithError(MethodCall *call, std::string reason);
115 :
116 : private:
117 : int m_sock;
118 : int m_epollFd;
119 : int m_timerFd;
120 : int m_eventFd;
121 :
122 : SimpleBufferPool m_sendBufPool;
123 :
124 : mutable std::mutex m_lock;
125 : std::atomic<uint64_t> m_serialCounter;
126 :
127 : const std::chrono::milliseconds m_timeout;
128 :
129 : std::map<uint64_t, MethodCall> m_methodCalls;
130 :
131 : std::mutex m_eventsLock;
132 :
133 : int m_eventTagCounter;
134 :
135 : struct Event
136 : {
137 : int id;
138 : const google::protobuf::Descriptor *descriptor;
139 : EventHandler handler;
140 : };
141 :
142 : std::multimap<std::string, Event> m_eventHandlers;
143 : };
144 :
145 : } // namespace firebolt::rialto::ipc
146 :
147 : #endif // FIREBOLT_RIALTO_IPC_IPC_CHANNEL_IMPL_H_
|