Line data Source code
1 : /*
2 : * Copyright (C) 2004-2025 Savoir-faire Linux Inc.
3 : *
4 : * This program is free software: you can redistribute it and/or modify
5 : * it under the terms of the GNU General Public License as published by
6 : * the Free Software Foundation, either version 3 of the License, or
7 : * (at your option) any later version.
8 : *
9 : * This program is distributed in the hope that it will be useful,
10 : * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 : * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 : * GNU General Public License for more details.
13 : *
14 : * You should have received a copy of the GNU General Public License
15 : * along with this program. If not, see <https://www.gnu.org/licenses/>.
16 : */
17 : #include "jamidht/message_channel_handler.h"
18 :
19 : #include <dhtnet/channel_utils.h>
20 : #include <string_view>
21 :
22 : using namespace std::literals;
23 :
24 : static constexpr auto MESSAGE_SCHEME = "msg:"sv;
25 :
26 : namespace jami {
27 :
28 : using Key = std::pair<std::string, DeviceId>;
29 :
30 : struct MessageChannelHandler::Impl : public std::enable_shared_from_this<Impl>
31 : {
32 : dhtnet::ConnectionManager& connectionManager_;
33 : OnMessage onMessage_;
34 : OnPeerStateChanged onPeerStateChanged_;
35 : std::recursive_mutex connectionsMtx_;
36 : std::map<std::string, std::map<DeviceId, std::vector<std::shared_ptr<dhtnet::ChannelSocket>>>> connections_;
37 :
38 669 : Impl(dhtnet::ConnectionManager& cm, OnMessage onMessage, OnPeerStateChanged onPeer)
39 1338 : : connectionManager_(cm)
40 669 : , onMessage_(std::move(onMessage))
41 1338 : , onPeerStateChanged_(std::move(onPeer))
42 669 : {}
43 :
44 : void onChannelShutdown(const std::shared_ptr<dhtnet::ChannelSocket>& socket,
45 : const std::string& peerId,
46 : const DeviceId& device);
47 : };
48 :
49 669 : MessageChannelHandler::MessageChannelHandler(dhtnet::ConnectionManager& cm,
50 : OnMessage onMessage,
51 669 : OnPeerStateChanged onPeer)
52 : : ChannelHandlerInterface()
53 669 : , pimpl_(std::make_shared<Impl>(cm, std::move(onMessage), std::move(onPeer)))
54 669 : {}
55 :
56 1338 : MessageChannelHandler::~MessageChannelHandler()
57 : {
58 669 : std::unique_lock lk(pimpl_->connectionsMtx_);
59 1180 : for (const auto& [peerId, _] : pimpl_->connections_) {
60 511 : pimpl_->onPeerStateChanged_(peerId, false);
61 : }
62 669 : auto connections = std::move(pimpl_->connections_);
63 669 : pimpl_->connections_.clear();
64 669 : lk.unlock();
65 1338 : }
66 :
67 : void
68 1371 : MessageChannelHandler::connect(const DeviceId& deviceId,
69 : const std::string&,
70 : ConnectCb&& cb,
71 : const std::string& connectionType,
72 : bool forceNewConnection)
73 : {
74 1371 : auto channelName = concat(MESSAGE_SCHEME, deviceId.to_view());
75 1371 : if (pimpl_->connectionManager_.isConnecting(deviceId, channelName)) {
76 2420 : JAMI_LOG("Already connecting to {}", deviceId);
77 605 : return;
78 : }
79 765 : pimpl_->connectionManager_
80 765 : .connectDevice(deviceId, channelName, std::move(cb), false, forceNewConnection, connectionType);
81 1371 : }
82 :
83 : void
84 686 : MessageChannelHandler::Impl::onChannelShutdown(const std::shared_ptr<dhtnet::ChannelSocket>& socket,
85 : const std::string& peerId,
86 : const DeviceId& device)
87 : {
88 686 : std::lock_guard lk(connectionsMtx_);
89 686 : auto peerIt = connections_.find(peerId);
90 686 : if (peerIt == connections_.end()) {
91 32 : JAMI_WARNING("onChannelShutdown: No connections found for peer {}", peerId);
92 8 : return;
93 : }
94 678 : auto connectionsIt = peerIt->second.find(device);
95 678 : if (connectionsIt == peerIt->second.end()) {
96 0 : JAMI_WARNING("onChannelShutdown: No connections found for device {} of peer {}", device.toString(), peerId);
97 0 : return;
98 : }
99 678 : auto& connections = connectionsIt->second;
100 678 : auto conn = std::find(connections.begin(), connections.end(), socket);
101 678 : if (conn != connections.end())
102 678 : connections.erase(conn);
103 678 : if (connections.empty()) {
104 573 : peerIt->second.erase(connectionsIt);
105 : }
106 678 : if (peerIt->second.empty()) {
107 571 : connections_.erase(peerIt);
108 571 : onPeerStateChanged_(peerId, false);
109 : }
110 686 : }
111 :
112 : std::shared_ptr<dhtnet::ChannelSocket>
113 14248 : MessageChannelHandler::getChannel(const std::string& peer, const DeviceId& deviceId) const
114 : {
115 14248 : std::lock_guard lk(pimpl_->connectionsMtx_);
116 14248 : auto it = pimpl_->connections_.find(peer);
117 14248 : if (it == pimpl_->connections_.end())
118 1424 : return nullptr;
119 12824 : auto deviceIt = it->second.find(deviceId);
120 12824 : if (deviceIt == it->second.end())
121 6 : return nullptr;
122 12818 : if (deviceIt->second.empty())
123 0 : return nullptr;
124 12818 : return deviceIt->second.back();
125 14248 : }
126 :
127 : std::vector<std::shared_ptr<dhtnet::ChannelSocket>>
128 3324 : MessageChannelHandler::getChannels(const std::string& peer) const
129 : {
130 3324 : std::vector<std::shared_ptr<dhtnet::ChannelSocket>> sockets;
131 3324 : std::lock_guard lk(pimpl_->connectionsMtx_);
132 3324 : auto it = pimpl_->connections_.find(peer);
133 3324 : if (it == pimpl_->connections_.end())
134 1786 : return sockets;
135 1538 : sockets.reserve(it->second.size());
136 3076 : for (auto& [deviceId, channels] : it->second) {
137 3288 : for (auto& channel : channels) {
138 1750 : sockets.push_back(channel);
139 : }
140 : }
141 1538 : return sockets;
142 3324 : }
143 :
144 : bool
145 668 : MessageChannelHandler::onRequest(const std::shared_ptr<dht::crypto::Certificate>& cert, const std::string& /* name */)
146 : {
147 668 : if (!cert || !cert->issuer)
148 0 : return false;
149 668 : return true;
150 : }
151 :
152 : void
153 1296 : MessageChannelHandler::onReady(const std::shared_ptr<dht::crypto::Certificate>& cert,
154 : const std::string&,
155 : std::shared_ptr<dhtnet::ChannelSocket> socket)
156 : {
157 1296 : if (!cert || !cert->issuer)
158 0 : return;
159 1296 : auto peerId = cert->issuer->getId().toString();
160 1296 : auto device = cert->getLongId();
161 1296 : std::lock_guard lk(pimpl_->connectionsMtx_);
162 1293 : auto& connections = pimpl_->connections_[peerId];
163 1295 : bool newPeerConnection = connections.empty();
164 1294 : auto& deviceConnections = connections[device];
165 1296 : deviceConnections.push_back(socket);
166 1294 : if (newPeerConnection)
167 1080 : pimpl_->onPeerStateChanged_(peerId, true);
168 :
169 1296 : socket->setOnRecv(dhtnet::buildMsgpackReader<Message>([onMessage = pimpl_->onMessage_, cert](Message&& msg) {
170 12538 : onMessage(cert, msg.t, msg.c);
171 12546 : return std::error_code();
172 : }));
173 :
174 2592 : socket->onShutdown(
175 1296 : [w = pimpl_->weak_from_this(), peerId, device, s = std::weak_ptr(socket)](const std::error_code& /*ec*/) {
176 1296 : if (auto shared = w.lock())
177 1296 : shared->onChannelShutdown(s.lock(), peerId, device);
178 1296 : });
179 1295 : }
180 :
181 : void
182 2 : MessageChannelHandler::closeChannel(const std::string& peer,
183 : const DeviceId& device,
184 : const std::shared_ptr<dhtnet::ChannelSocket>& conn)
185 : {
186 2 : if (!conn)
187 0 : return;
188 2 : std::unique_lock lk(pimpl_->connectionsMtx_);
189 2 : auto it = pimpl_->connections_.find(peer);
190 2 : if (it != pimpl_->connections_.end()) {
191 0 : auto deviceIt = it->second.find(device);
192 0 : if (deviceIt != it->second.end()) {
193 0 : auto& channels = deviceIt->second;
194 0 : channels.erase(std::remove(channels.begin(), channels.end(), conn), channels.end());
195 0 : if (channels.empty()) {
196 0 : it->second.erase(deviceIt);
197 0 : if (it->second.empty()) {
198 0 : pimpl_->connections_.erase(it);
199 0 : pimpl_->onPeerStateChanged_(peer, false);
200 : }
201 : }
202 : }
203 : }
204 2 : lk.unlock();
205 2 : conn->stop();
206 2 : }
207 :
208 : bool
209 12550 : MessageChannelHandler::sendMessage(const std::shared_ptr<dhtnet::ChannelSocket>& socket, const Message& message)
210 : {
211 12550 : if (!socket)
212 0 : return false;
213 12550 : msgpack::sbuffer buffer(UINT16_MAX); // Use max
214 12550 : msgpack::pack(buffer, message);
215 12545 : std::error_code ec;
216 12541 : auto sent = socket->write(reinterpret_cast<const uint8_t*>(buffer.data()), buffer.size(), ec);
217 12548 : if (ec) {
218 8 : JAMI_WARNING("Error sending message: {:s}", ec.message());
219 : }
220 12548 : return !ec && sent == buffer.size();
221 12547 : }
222 :
223 : } // namespace jami
|