Line data Source code
1 : /*
2 : * Copyright (C) 2004-2025 Savoir-faire Linux Inc.
3 : *
4 : * This program is free software: you can redistribute it and/or modify
5 : * it under the terms of the GNU General Public License as published by
6 : * the Free Software Foundation, either version 3 of the License, or
7 : * (at your option) any later version.
8 : *
9 : * This program is distributed in the hope that it will be useful,
10 : * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 : * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 : * GNU General Public License for more details.
13 : *
14 : * You should have received a copy of the GNU General Public License
15 : * along with this program. If not, see <https://www.gnu.org/licenses/>.
16 : */
17 : #include "jamidht/message_channel_handler.h"
18 :
19 : static constexpr const char MESSAGE_SCHEME[] {"msg:"};
20 :
21 : namespace jami {
22 :
23 : using Key = std::pair<std::string, DeviceId>;
24 :
25 : struct MessageChannelHandler::Impl : public std::enable_shared_from_this<Impl>
26 : {
27 : dhtnet::ConnectionManager& connectionManager_;
28 : OnMessage onMessage_;
29 : OnPeerStateChanged onPeerStateChanged_;
30 : std::recursive_mutex connectionsMtx_;
31 : std::map<std::string, std::map<DeviceId, std::vector<std::shared_ptr<dhtnet::ChannelSocket>>>> connections_;
32 :
33 569 : Impl(dhtnet::ConnectionManager& cm, OnMessage onMessage, OnPeerStateChanged onPeer)
34 1138 : : connectionManager_(cm)
35 569 : , onMessage_(std::move(onMessage))
36 1138 : , onPeerStateChanged_(std::move(onPeer))
37 569 : {}
38 :
39 : void onChannelShutdown(const std::shared_ptr<dhtnet::ChannelSocket>& socket,
40 : const std::string& peerId,
41 : const DeviceId& device);
42 : };
43 :
44 569 : MessageChannelHandler::MessageChannelHandler(dhtnet::ConnectionManager& cm,
45 569 : OnMessage onMessage, OnPeerStateChanged onPeer)
46 : : ChannelHandlerInterface()
47 569 : , pimpl_(std::make_shared<Impl>(cm, std::move(onMessage), std::move(onPeer)))
48 569 : {}
49 :
50 1138 : MessageChannelHandler::~MessageChannelHandler() {}
51 :
52 : void
53 1355 : MessageChannelHandler::connect(const DeviceId& deviceId,
54 : const std::string&,
55 : ConnectCb&& cb,
56 : const std::string& connectionType,
57 : bool forceNewConnection)
58 : {
59 1355 : auto channelName = MESSAGE_SCHEME + deviceId.toString();
60 1355 : if (pimpl_->connectionManager_.isConnecting(deviceId, channelName)) {
61 1935 : JAMI_LOG("Already connecting to {}", deviceId);
62 645 : return;
63 : }
64 1420 : pimpl_->connectionManager_.connectDevice(deviceId,
65 : channelName,
66 710 : std::move(cb),
67 : false,
68 : forceNewConnection,
69 : connectionType);
70 1355 : }
71 :
72 : void
73 646 : MessageChannelHandler::Impl::onChannelShutdown(const std::shared_ptr<dhtnet::ChannelSocket>& socket,
74 : const std::string& peerId,
75 : const DeviceId& device)
76 : {
77 646 : std::lock_guard lk(connectionsMtx_);
78 646 : auto peerIt = connections_.find(peerId);
79 646 : if (peerIt == connections_.end())
80 1 : return;
81 645 : auto connectionsIt = peerIt->second.find(device);
82 645 : if (connectionsIt == peerIt->second.end())
83 0 : return;
84 645 : auto& connections = connectionsIt->second;
85 645 : auto conn = std::find(connections.begin(), connections.end(), socket);
86 645 : if (conn != connections.end())
87 645 : connections.erase(conn);
88 645 : if (connections.empty()) {
89 532 : peerIt->second.erase(connectionsIt);
90 : }
91 645 : if (peerIt->second.empty()) {
92 528 : connections_.erase(peerIt);
93 528 : onPeerStateChanged_(peerId, false);
94 : }
95 646 : }
96 :
97 : std::shared_ptr<dhtnet::ChannelSocket>
98 14918 : MessageChannelHandler::getChannel(const std::string& peer, const DeviceId& deviceId) const
99 : {
100 14918 : std::lock_guard lk(pimpl_->connectionsMtx_);
101 14918 : auto it = pimpl_->connections_.find(peer);
102 14918 : if (it == pimpl_->connections_.end())
103 1390 : return nullptr;
104 13528 : auto deviceIt = it->second.find(deviceId);
105 13528 : if (deviceIt == it->second.end())
106 10 : return nullptr;
107 13518 : if (deviceIt->second.empty())
108 0 : return nullptr;
109 13518 : return deviceIt->second.back();
110 14918 : }
111 :
112 : std::vector<std::shared_ptr<dhtnet::ChannelSocket>>
113 3158 : MessageChannelHandler::getChannels(const std::string& peer) const
114 : {
115 3158 : std::vector<std::shared_ptr<dhtnet::ChannelSocket>> sockets;
116 3158 : std::lock_guard lk(pimpl_->connectionsMtx_);
117 3158 : auto it = pimpl_->connections_.find(peer);
118 3158 : if (it == pimpl_->connections_.end())
119 1630 : return sockets;
120 1528 : sockets.reserve(it->second.size());
121 3057 : for (auto& [deviceId, channels] : it->second) {
122 3289 : for (auto& channel : channels) {
123 1760 : sockets.push_back(channel);
124 : }
125 : }
126 1528 : return sockets;
127 3158 : }
128 :
129 : bool
130 636 : MessageChannelHandler::onRequest(const std::shared_ptr<dht::crypto::Certificate>& cert,
131 : const std::string& /* name */)
132 : {
133 636 : if (!cert || !cert->issuer)
134 0 : return false;
135 636 : return true;
136 : }
137 :
138 : void
139 1230 : MessageChannelHandler::onReady(const std::shared_ptr<dht::crypto::Certificate>& cert,
140 : const std::string&,
141 : std::shared_ptr<dhtnet::ChannelSocket> socket)
142 : {
143 1230 : if (!cert || !cert->issuer)
144 0 : return;
145 1230 : auto peerId = cert->issuer->getId().toString();
146 1230 : auto device = cert->getLongId();
147 1230 : std::lock_guard lk(pimpl_->connectionsMtx_);
148 1230 : auto& connections = pimpl_->connections_[peerId];
149 1230 : bool newPeerConnection = connections.empty();
150 1230 : auto& deviceConnections = connections[device];
151 1230 : deviceConnections.push_back(socket);
152 1230 : if (newPeerConnection)
153 997 : pimpl_->onPeerStateChanged_(peerId, true);
154 :
155 1230 : socket->onShutdown([w = pimpl_->weak_from_this(), peerId, device, s = std::weak_ptr(socket)]() {
156 1230 : if (auto shared = w.lock())
157 1230 : shared->onChannelShutdown(s.lock(), peerId, device);
158 1230 : });
159 :
160 : struct DecodingContext
161 : {
162 79494 : msgpack::unpacker pac {[](msgpack::type::object_type, std::size_t, void*) { return true; },
163 : nullptr,
164 : 16 * 1024};
165 : };
166 :
167 1230 : socket->setOnRecv([onMessage = pimpl_->onMessage_,
168 : peerId,
169 : cert,
170 : ctx = std::make_shared<DecodingContext>()](const uint8_t* buf, size_t len) {
171 13249 : if (!buf)
172 0 : return len;
173 :
174 13249 : ctx->pac.reserve_buffer(len);
175 13249 : std::copy_n(buf, len, ctx->pac.buffer());
176 13249 : ctx->pac.buffer_consumed(len);
177 :
178 13249 : msgpack::object_handle oh;
179 : try {
180 26498 : while (ctx->pac.next(oh)) {
181 13249 : Message msg;
182 13249 : oh.get().convert(msg);
183 13249 : onMessage(cert, msg.t, msg.c);
184 13249 : }
185 0 : } catch (const std::exception& e) {
186 0 : JAMI_WARNING("[convInfo] error on sync: {:s}", e.what());
187 0 : }
188 13249 : return len;
189 13249 : });
190 1230 : }
191 :
192 : void
193 1 : MessageChannelHandler::closeChannel(const std::string& peer, const DeviceId& device, const std::shared_ptr<dhtnet::ChannelSocket>& conn)
194 : {
195 1 : if (!conn)
196 0 : return;
197 1 : std::unique_lock lk(pimpl_->connectionsMtx_);
198 1 : auto it = pimpl_->connections_.find(peer);
199 1 : if (it != pimpl_->connections_.end()) {
200 1 : auto deviceIt = it->second.find(device);
201 1 : if (deviceIt != it->second.end()) {
202 1 : auto& channels = deviceIt->second;
203 1 : channels.erase(std::remove(channels.begin(), channels.end(), conn), channels.end());
204 1 : if (channels.empty()) {
205 1 : it->second.erase(deviceIt);
206 1 : if (it->second.empty()) {
207 1 : pimpl_->connections_.erase(it);
208 : }
209 : }
210 : }
211 : }
212 1 : lk.unlock();
213 1 : conn->stop();
214 1 : }
215 :
216 : bool
217 13250 : MessageChannelHandler::sendMessage(const std::shared_ptr<dhtnet::ChannelSocket>& socket,
218 : const Message& message)
219 : {
220 13250 : if (!socket)
221 0 : return false;
222 13250 : msgpack::sbuffer buffer(UINT16_MAX); // Use max
223 13250 : msgpack::pack(buffer, message);
224 13250 : std::error_code ec;
225 13250 : auto sent = socket->write(reinterpret_cast<const uint8_t*>(buffer.data()), buffer.size(), ec);
226 13250 : if (ec) {
227 3 : JAMI_WARNING("Error sending message: {:s}", ec.message());
228 : }
229 13250 : return !ec && sent == buffer.size();
230 13250 : }
231 :
232 : } // namespace jami
|