Line data Source code
1 : /*
2 : * If not stated otherwise in this file or this component's LICENSE file the
3 : * following copyright and licenses apply:
4 : *
5 : * Copyright 2022 Sky UK
6 : *
7 : * Licensed under the Apache License, Version 2.0 (the "License");
8 : * you may not use this file except in compliance with the License.
9 : * You may obtain a copy of the License at
10 : *
11 : * http://www.apache.org/licenses/LICENSE-2.0
12 : *
13 : * Unless required by applicable law or agreed to in writing, software
14 : * distributed under the License is distributed on an "AS IS" BASIS,
15 : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 : * See the License for the specific language governing permissions and
17 : * limitations under the License.
18 : */
19 :
20 : #ifndef FIREBOLT_RIALTO_IPC_IPC_SERVER_IMPL_H_
21 : #define FIREBOLT_RIALTO_IPC_IPC_SERVER_IMPL_H_
22 :
23 : #include "FileDescriptor.h"
24 : #include "IIpcServer.h"
25 : #include "IIpcServerFactory.h"
26 : #include "IpcServerControllerImpl.h"
27 : #include "SimpleBufferPool.h"
28 :
29 : #include "rialtoipc-transport.pb.h"
30 :
31 : #include <sys/socket.h>
32 :
33 : #include <atomic>
34 : #include <list>
35 : #include <map>
36 : #include <memory>
37 : #include <mutex>
38 : #include <set>
39 : #include <string>
40 : #include <thread>
41 : #include <vector>
42 :
// Fallback for the maximum number of file descriptors carried in one SCM_RIGHTS message, used only when the
// system headers do not already define it (userspace headers such as glibc's <sys/socket.h> typically do not).
// It sizes ServerImpl's ancillary-data receive buffer.
// NOTE(review): recent Linux kernels use 253 internally; confirm 255 is the intended upper bound.
#ifndef SCM_MAX_FD
#define SCM_MAX_FD 255
#endif
46 :
47 : namespace firebolt::rialto::ipc
48 : {
49 : class ClientImpl;
50 :
51 : class ServerFactory : public IServerFactory
52 : {
53 : public:
54 32 : ServerFactory() = default;
55 32 : ~ServerFactory() override = default;
56 :
57 : std::shared_ptr<IServer> create() override;
58 : };
59 :
60 : class ServerImpl final : public ::firebolt::rialto::ipc::IServer, public std::enable_shared_from_this<ServerImpl>
61 : {
62 : public:
63 : explicit ServerImpl();
64 : ~ServerImpl() final;
65 :
66 : public:
67 : bool addSocket(const std::string &socketPath, std::function<void(const std::shared_ptr<IClient> &)> clientConnectedCb,
68 : std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb) override;
69 : bool addSocket(int fd, std::function<void(const std::shared_ptr<IClient> &)> clientConnectedCb,
70 : std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb) override;
71 :
72 : std::shared_ptr<IClient>
73 : addClient(int socketFd, std::function<void(const std::shared_ptr<IClient> &)> clientDisconnectedCb) override;
74 :
75 : int fd() const override;
76 : bool wait(int timeoutMSecs) override;
77 : bool process() override;
78 :
79 : protected:
80 : friend class ClientImpl;
81 : bool sendEvent(uint64_t clientId, const std::shared_ptr<google::protobuf::Message> &message);
82 : bool isClientConnected(uint64_t clientId) const;
83 : void disconnectClient(uint64_t clientId);
84 :
85 : private:
86 : struct Socket;
87 : static bool getSocketLock(Socket *socket);
88 : static void closeListeningSocket(Socket *socket);
89 :
90 : void wakeEventLoop() const;
91 :
92 : void processNewConnection(uint64_t socketId);
93 :
94 : void processClientSocket(uint64_t clientId, unsigned events);
95 : void processClientMessage(const std::shared_ptr<ClientImpl> &client, const uint8_t *data, size_t dataLen,
96 : const std::vector<FileDescriptor> &fds = {});
97 :
98 : void processMethodCall(const std::shared_ptr<ClientImpl> &client, const transport::MethodCall &call,
99 : const std::vector<FileDescriptor> &fds);
100 :
101 : std::shared_ptr<ClientImpl> addClientSocket(int socketFd, const std::string &listeningSocketPath,
102 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb);
103 :
104 : void sendReply(uint64_t clientId, const std::shared_ptr<msghdr> &msg);
105 :
106 : void sendErrorReply(const std::shared_ptr<ClientImpl> &client, uint64_t serialId, const char *format, ...)
107 : __attribute__((format(printf, 4, 5)));
108 :
109 : void handleResponse(ServerControllerImpl *controller, google::protobuf::Message *response);
110 :
111 : std::shared_ptr<msghdr> populateReply(const std::shared_ptr<const ClientImpl> &client, uint64_t serialId,
112 : google::protobuf::Message *response);
113 : std::shared_ptr<msghdr> populateErrorReply(const std::shared_ptr<const ClientImpl> &client, uint64_t serialId,
114 : const std::string &reason);
115 :
116 : private:
117 : static const size_t m_kMaxMessageLen;
118 :
119 : int m_pollFd;
120 : int m_wakeEventFd;
121 :
122 : std::atomic<uint64_t> m_socketIdCounter;
123 : std::atomic<uint64_t> m_clientIdCounter;
124 :
125 : struct Socket
126 : {
127 : int sockFd = -1;
128 : int lockFd = -1;
129 : std::string sockPath;
130 : std::string lockPath;
131 : std::function<void(const std::shared_ptr<IClient> &)> connectedCb;
132 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb;
133 : bool isOwned = true;
134 : };
135 :
136 : std::mutex m_socketsLock;
137 : std::map<uint64_t, Socket> m_sockets;
138 :
139 : struct ClientDetails
140 : {
141 : int sock = -1;
142 : std::shared_ptr<ClientImpl> client;
143 : std::function<void(const std::shared_ptr<IClient> &)> disconnectedCb;
144 : };
145 :
146 : mutable std::mutex m_clientsLock;
147 :
148 : std::map<uint64_t, ClientDetails> m_clients;
149 : std::set<uint64_t> m_condemnedClients;
150 :
151 : uint8_t m_recvDataBuf[128 * 1024];
152 : uint8_t m_recvCtrlBuf[SCM_MAX_FD * sizeof(int)];
153 :
154 : SimpleBufferPool m_sendBufPool;
155 : };
156 :
157 : } // namespace firebolt::rialto::ipc
158 :
159 : #endif // FIREBOLT_RIALTO_IPC_IPC_SERVER_IMPL_H_
|