Line data Source code
1 : /*
2 : * If not stated otherwise in this file or this component's LICENSE file the
3 : * following copyright and licenses apply:
4 : *
5 : * Copyright 2022 Sky UK
6 : *
7 : * Licensed under the Apache License, Version 2.0 (the "License");
8 : * you may not use this file except in compliance with the License.
9 : * You may obtain a copy of the License at
10 : *
11 : * http://www.apache.org/licenses/LICENSE-2.0
12 : *
13 : * Unless required by applicable law or agreed to in writing, software
14 : * distributed under the License is distributed on an "AS IS" BASIS,
15 : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 : * See the License for the specific language governing permissions and
17 : * limitations under the License.
18 : */
19 :
20 : #ifndef FIREBOLT_RIALTO_IPC_IPC_SERVER_IMPL_H_
21 : #define FIREBOLT_RIALTO_IPC_IPC_SERVER_IMPL_H_
22 :
#include "FileDescriptor.h"
#include "IIpcServer.h"
#include "IIpcServerFactory.h"
#include "IpcServerControllerImpl.h"
#include "SimpleBufferPool.h"

#include "rialtoipc-transport.pb.h"

#include <sys/socket.h>

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <vector>
42 :
// Maximum number of file descriptors the server expects in one received control
// message; it sizes the receive control buffer. Supplies a fallback value when the
// system headers do not define it.
#ifndef SCM_MAX_FD
#define SCM_MAX_FD 255
#endif
46 :
47 : namespace firebolt::rialto::ipc
48 : {
49 : class ClientImpl;
50 :
51 : class ServerFactory : public IServerFactory
52 : {
53 : public:
54 26 : ServerFactory() = default;
55 26 : ~ServerFactory() override = default;
56 :
57 : std::shared_ptr<IServer> create() override;
58 : };
59 :
60 : class ServerImpl final : public ::firebolt::rialto::ipc::IServer, public std::enable_shared_from_this<ServerImpl>
61 : {
62 : public:
63 : explicit ServerImpl();
64 : ~ServerImpl() final;
65 :
66 : public:
67 : bool addSocket(const std::string &socketPath, std::function<void(const std::shared_ptr<IClient> &)> clientConnectedCb,
68 : std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb) override;
69 :
70 : std::shared_ptr<IClient>
71 : addClient(int socketFd, std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb) override;
72 :
73 : int fd() const override;
74 : bool wait(int timeoutMSecs) override;
75 : bool process() override;
76 :
77 : protected:
78 : friend class ClientImpl;
79 : bool sendEvent(uint64_t clientId, const std::shared_ptr<google::protobuf::Message> &message);
80 : bool isClientConnected(uint64_t clientId) const;
81 : void disconnectClient(uint64_t clientId);
82 :
83 : private:
84 : struct Socket;
85 : static bool getSocketLock(Socket *socket);
86 : static void closeListeningSocket(Socket *socket);
87 :
88 : void wakeEventLoop() const;
89 :
90 : void processNewConnection(uint64_t socketId);
91 :
92 : void processClientSocket(uint64_t clientId, unsigned events);
93 : void processClientMessage(const std::shared_ptr<ClientImpl> &client, const uint8_t *data, size_t dataLen,
94 : const std::vector<FileDescriptor> &fds = {});
95 :
96 : void processMethodCall(const std::shared_ptr<ClientImpl> &client, const transport::MethodCall &call,
97 : const std::vector<FileDescriptor> &fds);
98 :
99 : std::shared_ptr<ClientImpl> addClientSocket(int socketFd, const std::string &listeningSocketPath,
100 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb);
101 :
102 : void sendReply(uint64_t clientId, const std::shared_ptr<msghdr> &msg);
103 :
104 : void sendErrorReply(const std::shared_ptr<ClientImpl> &client, uint64_t serialId, const char *format, ...)
105 : __attribute__((format(printf, 4, 5)));
106 :
107 : void handleResponse(ServerControllerImpl *controller, google::protobuf::Message *response);
108 :
109 : std::shared_ptr<msghdr> populateReply(const std::shared_ptr<const ClientImpl> &client, uint64_t serialId,
110 : google::protobuf::Message *response);
111 : std::shared_ptr<msghdr> populateErrorReply(const std::shared_ptr<const ClientImpl> &client, uint64_t serialId,
112 : const std::string &reason);
113 :
114 : private:
115 : static const size_t m_kMaxMessageLen;
116 :
117 : int m_pollFd;
118 : int m_wakeEventFd;
119 :
120 : std::atomic<uint64_t> m_socketIdCounter;
121 : std::atomic<uint64_t> m_clientIdCounter;
122 :
123 : struct Socket
124 : {
125 : int sockFd = -1;
126 : int lockFd = -1;
127 : std::string sockPath;
128 : std::string lockPath;
129 : std::function<void(const std::shared_ptr<IClient> &)> connectedCb;
130 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb;
131 : };
132 :
133 : std::mutex m_socketsLock;
134 : std::map<uint64_t, Socket> m_sockets;
135 :
136 : struct ClientDetails
137 : {
138 : int sock = -1;
139 : std::shared_ptr<ClientImpl> client;
140 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb;
141 : };
142 :
143 : mutable std::mutex m_clientsLock;
144 :
145 : std::map<uint64_t, ClientDetails> m_clients;
146 : std::set<uint64_t> m_condemnedClients;
147 :
148 : uint8_t m_recvDataBuf[128 * 1024];
149 : uint8_t m_recvCtrlBuf[SCM_MAX_FD * sizeof(int)];
150 :
151 : SimpleBufferPool m_sendBufPool;
152 : };
153 :
154 : } // namespace firebolt::rialto::ipc
155 :
156 : #endif // FIREBOLT_RIALTO_IPC_IPC_SERVER_IMPL_H_
|