Line data Source code
1 : /*
2 : * If not stated otherwise in this file or this component's LICENSE file the
3 : * following copyright and licenses apply:
4 : *
5 : * Copyright 2022 Sky UK
6 : *
7 : * Licensed under the Apache License, Version 2.0 (the "License");
8 : * you may not use this file except in compliance with the License.
9 : * You may obtain a copy of the License at
10 : *
11 : * http://www.apache.org/licenses/LICENSE-2.0
12 : *
13 : * Unless required by applicable law or agreed to in writing, software
14 : * distributed under the License is distributed on an "AS IS" BASIS,
15 : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 : * See the License for the specific language governing permissions and
17 : * limitations under the License.
18 : */
19 :
20 : #ifndef FIREBOLT_RIALTO_IPC_IPC_SERVER_IMPL_H_
21 : #define FIREBOLT_RIALTO_IPC_IPC_SERVER_IMPL_H_
22 :
23 : #include "FileDescriptor.h"
24 : #include "IIpcServer.h"
25 : #include "IIpcServerFactory.h"
26 : #include "IpcServerControllerImpl.h"
27 : #include "SimpleBufferPool.h"
28 :
29 : #include "rialtoipc-transport.pb.h"
30 :
31 : #include <sys/socket.h>
32 :
33 : #include <atomic>
34 : #include <list>
35 : #include <map>
36 : #include <memory>
37 : #include <mutex>
38 : #include <set>
39 : #include <string>
40 : #include <thread>
41 : #include <vector>
42 :
#ifndef SCM_MAX_FD
// Upper bound on the number of file descriptors that one control message is sized for
// (used below to size the ancillary-data receive buffer). The kernel's own SCM_MAX_FD is
// not normally visible to userspace headers, so a fixed fallback is supplied here.
// NOTE(review): unix(7) documents the kernel limit as 253 (255 before Linux 2.6.38);
// 255 is at least as large, so buffers sized from it are big enough.
#define SCM_MAX_FD 255
#endif
46 :
47 : namespace firebolt::rialto::ipc
48 : {
49 : class ClientImpl;
50 :
51 : class ServerFactory : public IServerFactory
52 : {
53 : public:
54 32 : ServerFactory() = default;
55 32 : ~ServerFactory() override = default;
56 :
57 : std::shared_ptr<IServer> create() override;
58 : };
59 :
60 : class ServerImpl final : public ::firebolt::rialto::ipc::IServer, public std::enable_shared_from_this<ServerImpl>
61 : {
62 : public:
63 : explicit ServerImpl();
64 : ~ServerImpl() final;
65 :
66 : public:
67 : bool addSocket(const std::string &socketPath,
68 : const std::function<void(const std::shared_ptr<IClient> &)> &clientConnectedCb,
69 : const std::function<void(const std::shared_ptr<IClient> &)> &clientDisconnectedCb) override;
70 : bool addSocket(int fd, const std::function<void(const std::shared_ptr<IClient> &)> &clientConnectedCb,
71 : const std::function<void(const std::shared_ptr<IClient> &)> &clientDisconnectedCb) override;
72 :
73 : std::shared_ptr<IClient>
74 : addClient(int socketFd, std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb) override;
75 :
76 : int fd() const override;
77 : bool wait(int timeoutMSecs) override;
78 : bool process() override;
79 :
80 : protected:
81 : friend class ClientImpl;
82 : bool sendEvent(uint64_t clientId, const std::shared_ptr<google::protobuf::Message> &message);
83 : bool isClientConnected(uint64_t clientId) const;
84 : void disconnectClient(uint64_t clientId);
85 :
86 : private:
87 : struct Socket;
88 : static bool getSocketLock(Socket *socket);
89 : static void closeListeningSocket(Socket *socket);
90 :
91 : void wakeEventLoop() const;
92 :
93 : void processNewConnection(uint64_t socketId);
94 :
95 : void processClientSocket(uint64_t clientId, unsigned events);
96 : void processClientMessage(const std::shared_ptr<ClientImpl> &client, const uint8_t *data, size_t dataLen,
97 : const std::vector<FileDescriptor> &fds = {});
98 :
99 : void processMethodCall(const std::shared_ptr<ClientImpl> &client, const transport::MethodCall &call,
100 : const std::vector<FileDescriptor> &fds);
101 :
102 : std::shared_ptr<ClientImpl> addClientSocket(int socketFd, const std::string &listeningSocketPath,
103 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb);
104 :
105 : void sendReply(uint64_t clientId, const std::shared_ptr<msghdr> &msg);
106 :
107 : void sendErrorReply(const std::shared_ptr<ClientImpl> &client, uint64_t serialId, const char *format, ...)
108 : __attribute__((format(printf, 4, 5)));
109 :
110 : void handleResponse(ServerControllerImpl *controller, google::protobuf::Message *response);
111 :
112 : std::shared_ptr<msghdr> populateReply(const std::shared_ptr<const ClientImpl> &client, uint64_t serialId,
113 : google::protobuf::Message *response);
114 : std::shared_ptr<msghdr> populateErrorReply(const std::shared_ptr<const ClientImpl> &client, uint64_t serialId,
115 : const std::string &reason);
116 :
117 : private:
118 : static const size_t m_kMaxMessageLen;
119 :
120 : int m_pollFd;
121 : int m_wakeEventFd;
122 :
123 : std::atomic<uint64_t> m_socketIdCounter;
124 : std::atomic<uint64_t> m_clientIdCounter;
125 :
126 : struct Socket
127 : {
128 : int sockFd = -1;
129 : int lockFd = -1;
130 : std::string sockPath;
131 : std::string lockPath;
132 : std::function<void(const std::shared_ptr<IClient> &)> connectedCb;
133 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb;
134 : bool isOwned = true;
135 : };
136 :
137 : std::mutex m_socketsLock;
138 : std::map<uint64_t, Socket> m_sockets;
139 :
140 : struct ClientDetails
141 : {
142 : int sock = -1;
143 : std::shared_ptr<ClientImpl> client;
144 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb;
145 : };
146 :
147 : mutable std::mutex m_clientsLock;
148 :
149 : std::map<uint64_t, ClientDetails> m_clients;
150 : std::set<uint64_t> m_condemnedClients;
151 :
152 : uint8_t m_recvDataBuf[128 * 1024];
153 : uint8_t m_recvCtrlBuf[SCM_MAX_FD * sizeof(int)];
154 :
155 : SimpleBufferPool m_sendBufPool;
156 : };
157 :
158 : } // namespace firebolt::rialto::ipc
159 :
160 : #endif // FIREBOLT_RIALTO_IPC_IPC_SERVER_IMPL_H_
|