Line data Source code
1 : /*
2 : * Copyright (C) 2004-2026 Savoir-faire Linux Inc.
3 : *
4 : * This program is free software: you can redistribute it and/or modify
5 : * it under the terms of the GNU General Public License as published by
6 : * the Free Software Foundation, either version 3 of the License, or
7 : * (at your option) any later version.
8 : *
9 : * This program is distributed in the hope that it will be useful,
10 : * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 : * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 : * GNU General Public License for more details.
13 : *
14 : * You should have received a copy of the GNU General Public License
15 : * along with this program. If not, see <https://www.gnu.org/licenses/>.
16 : */
17 : #include "jamidht/svc_discovery_channel_handler.h"
18 :
19 : #include "jamidht/account_manager.h"
20 : #include "jamidht/contact_list.h"
21 : #include "jamidht/service_manager.h"
22 : #include "logger.h"
23 :
24 : #include <algorithm>
25 : #include <cstring>
26 : #include <fstream>
27 :
28 : namespace jami {
29 :
30 : namespace {
31 :
32 : template<typename T>
33 : bool
34 2427 : sendMsg(const std::shared_ptr<dhtnet::ChannelSocket>& s, const T& msg)
35 : {
36 2427 : if (!s)
37 0 : return false;
38 2427 : msgpack::sbuffer buf;
39 2427 : msgpack::pack(buf, msg);
40 2427 : std::error_code ec;
41 2427 : auto sent = s->write(reinterpret_cast<const uint8_t*>(buf.data()), buf.size(), ec);
42 2427 : if (ec) {
43 0 : JAMI_WARNING("[SvcDiscovery] write error: {}", ec.message());
44 0 : return false;
45 : }
46 2427 : return sent == buf.size();
47 2427 : }
48 :
49 : constexpr const char* SVC_CACHE_FILENAME = "svc_discovery_cache.msgpack";
50 :
51 : } // namespace
52 :
53 672 : SvcDiscoveryChannelHandler::SvcDiscoveryChannelHandler(const std::shared_ptr<JamiAccount>& acc,
54 : dhtnet::ConnectionManager& cm,
55 672 : std::filesystem::path cachePath)
56 : : ChannelHandlerInterface()
57 672 : , account_(acc)
58 672 : , connectionManager_(cm)
59 672 : , state_(std::make_shared<State>())
60 1344 : , cachePath_(std::move(cachePath))
61 : {
62 672 : loadCache();
63 :
64 : // Install the default response callback that populates the cache.
65 672 : auto cachePathCopy = cachePath_;
66 1344 : state_->responseCb = [state = state_, cachePathCopy](const std::string& peerAccountUri,
67 : const std::string& peerDeviceId,
68 : const std::vector<svc_protocol::SvcInfo>& services) {
69 1208 : if (peerDeviceId.empty())
70 0 : return;
71 1208 : DeviceId dev;
72 : try {
73 1208 : dev = DeviceId(peerDeviceId);
74 0 : } catch (...) {
75 0 : return;
76 0 : }
77 1208 : CacheUpdateCb updateCb;
78 1208 : msgpack::sbuffer buf;
79 : {
80 1208 : std::lock_guard lk(state->mtx);
81 1208 : state->cache[dev] = State::CachedDeviceServices {peerAccountUri, services};
82 1207 : updateCb = state->cacheUpdateCb;
83 1207 : msgpack::pack(buf, state->cache);
84 1208 : }
85 : // Write to disk outside the lock
86 1208 : if (!cachePathCopy.empty()) {
87 1208 : std::error_code ec;
88 1208 : std::filesystem::create_directories(cachePathCopy, ec);
89 1208 : std::ofstream out(cachePathCopy / SVC_CACHE_FILENAME, std::ios::binary | std::ios::trunc);
90 1208 : if (out)
91 1208 : out.write(buf.data(), buf.size());
92 1208 : }
93 1208 : if (updateCb)
94 1208 : updateCb(peerAccountUri, dev, services);
95 1880 : };
96 672 : }
97 :
98 1344 : SvcDiscoveryChannelHandler::~SvcDiscoveryChannelHandler() = default;
99 :
100 : void
101 672 : SvcDiscoveryChannelHandler::onCacheUpdated(CacheUpdateCb cb)
102 : {
103 672 : std::lock_guard lk(state_->mtx);
104 672 : state_->cacheUpdateCb = std::move(cb);
105 672 : }
106 :
107 : void
108 0 : SvcDiscoveryChannelHandler::setOnResponse(ResponseCb cb)
109 : {
110 0 : std::lock_guard lk(state_->mtx);
111 0 : state_->responseCb = std::move(cb);
112 0 : }
113 :
114 : svc_protocol::SvcDiscResponse
115 1213 : SvcDiscoveryChannelHandler::buildResponse(JamiAccount& account, const std::string& peerAccountUri)
116 : {
117 1213 : svc_protocol::SvcDiscResponse out;
118 1213 : if (const auto& id = account.identity().first)
119 1213 : out.device = id->getPublicKey().getLongId().toString();
120 0 : auto checker = [&account](const std::string& uri) {
121 0 : return account.isContact(uri);
122 1213 : };
123 1213 : auto visible = account.serviceManager().getVisibleServices(peerAccountUri, checker);
124 1213 : out.services.reserve(visible.size());
125 1213 : for (auto& r : visible) {
126 0 : svc_protocol::SvcInfo info;
127 0 : info.id = std::move(r.id);
128 0 : info.name = std::move(r.name);
129 0 : info.description = std::move(r.description);
130 0 : info.proto = "tcp";
131 0 : info.scheme = std::move(r.scheme);
132 0 : out.services.push_back(std::move(info));
133 0 : }
134 2426 : return out;
135 1213 : }
136 :
137 : void
138 1219 : SvcDiscoveryChannelHandler::connect(const DeviceId& deviceId,
139 : const std::string& /*name*/,
140 : ConnectCb&& cb,
141 : const std::string& /*connectionType*/,
142 : bool /*forceNewConnection*/)
143 : {
144 1219 : auto userCb = std::make_shared<ConnectCb>(std::move(cb));
145 1219 : auto state = state_;
146 1219 : auto wacc = account_;
147 1219 : connectionManager_.connectDevice(deviceId,
148 2438 : std::string(svc_protocol::DiscoveryChannelName),
149 2438 : [userCb, state, wacc, this](std::shared_ptr<dhtnet::ChannelSocket> socket,
150 : const DeviceId& dev) {
151 1219 : if (socket) {
152 : // Retain the channel for its full lifetime so the response
153 : // can come back even if no one else holds it.
154 1214 : auto cert = socket->peerCertificate();
155 1214 : std::string peerAccountUri;
156 1214 : if (cert && cert->issuer)
157 1214 : peerAccountUri = cert->issuer->getId().toString();
158 : {
159 1214 : std::lock_guard lk(state->mtx);
160 1214 : state->channels[peerAccountUri].push_back(socket);
161 1214 : }
162 1214 : socket->onShutdown([state, ws = std::weak_ptr(socket), peerAccountUri](
163 : const std::error_code&) {
164 1214 : auto s = ws.lock();
165 1214 : if (!s)
166 0 : return;
167 1214 : std::lock_guard lk(state->mtx);
168 1214 : auto it = state->channels.find(peerAccountUri);
169 1214 : if (it != state->channels.end()) {
170 1214 : auto& vec = it->second;
171 1214 : vec.erase(std::remove(vec.begin(), vec.end(), s), vec.end());
172 1214 : if (vec.empty())
173 602 : state->channels.erase(it);
174 : }
175 1214 : });
176 : // The initiating side immediately writes a Query so the server
177 : // can respond. We need to install a reader to handle the
178 : // response too.
179 1214 : installReader(socket, peerAccountUri);
180 3642 : if (!sendMsg(socket, svc_protocol::SvcDiscQuery {}))
181 0 : JAMI_WARNING("[SvcDiscovery] failed to send SvcDiscQuery to {}",
182 : peerAccountUri);
183 1214 : }
184 1219 : if (*userCb)
185 1219 : (*userCb)(socket, dev);
186 1219 : });
187 2438 : }
188 :
189 : bool
190 1215 : SvcDiscoveryChannelHandler::onRequest(const std::shared_ptr<dht::crypto::Certificate>& peer, const std::string& /*name*/)
191 : {
192 1215 : return peer && peer->issuer;
193 : }
194 :
195 : void
196 2428 : SvcDiscoveryChannelHandler::onReady(const std::shared_ptr<dht::crypto::Certificate>& peer,
197 : const std::string& /*name*/,
198 : std::shared_ptr<dhtnet::ChannelSocket> channel)
199 : {
200 2428 : if (!channel)
201 0 : return;
202 : // The initiator already retains the socket and installs its reader from
203 : // connect(); no need to do the work twice.
204 2428 : if (channel->isInitiator())
205 1214 : return;
206 1214 : if (!peer || !peer->issuer) {
207 0 : channel->shutdown();
208 0 : return;
209 : }
210 1214 : auto peerUri = peer->issuer->getId().toString();
211 : {
212 1214 : std::lock_guard lk(state_->mtx);
213 1214 : state_->channels[peerUri].push_back(channel);
214 1214 : }
215 1214 : auto state = state_;
216 1214 : channel->onShutdown([state, ws = std::weak_ptr(channel), peerUri](const std::error_code&) {
217 1214 : auto s = ws.lock();
218 1214 : if (!s)
219 0 : return;
220 1214 : std::lock_guard lk(state->mtx);
221 1213 : auto it = state->channels.find(peerUri);
222 1214 : if (it != state->channels.end()) {
223 1214 : auto& vec = it->second;
224 1214 : vec.erase(std::remove(vec.begin(), vec.end(), s), vec.end());
225 1214 : if (vec.empty())
226 605 : state->channels.erase(it);
227 : }
228 1214 : });
229 1214 : installReader(channel, peer->issuer->getId().toString());
230 1213 : }
231 :
232 : void
233 2425 : SvcDiscoveryChannelHandler::installReader(const std::shared_ptr<dhtnet::ChannelSocket>& channel,
234 : std::string peerAccountUri)
235 : {
236 2425 : auto reader = std::make_shared<msgpack::unpacker>();
237 2428 : reader->reserve_buffer(4096);
238 2428 : auto wacc = account_;
239 2427 : auto state = state_;
240 2427 : std::weak_ptr<dhtnet::ChannelSocket> wsock = channel;
241 :
242 2427 : channel->setOnRecv([reader, wacc, state, wsock, peerAccountUri = std::move(peerAccountUri)](const uint8_t* data,
243 : size_t size) -> ssize_t {
244 2421 : if (size == 0)
245 0 : return 0;
246 2421 : if (reader->buffer_capacity() < size)
247 0 : reader->reserve_buffer(size);
248 2421 : std::memcpy(reader->buffer(), data, size);
249 2421 : reader->buffer_consumed(size);
250 :
251 2421 : msgpack::object_handle oh;
252 4842 : while (reader->next(oh)) {
253 2421 : const auto& obj = oh.get();
254 2421 : const auto type = svc_protocol::peekType(obj);
255 2421 : const auto v = svc_protocol::peekVersion(obj);
256 2421 : auto sock = wsock.lock();
257 2421 : if (!sock)
258 0 : return static_cast<ssize_t>(size);
259 :
260 2421 : if (type == svc_protocol::MsgType::Query) {
261 1213 : auto acc = wacc.lock();
262 1213 : if (!acc) {
263 0 : sock->shutdown();
264 0 : continue;
265 : }
266 1213 : if (v > svc_protocol::MaxVersion) {
267 0 : svc_protocol::SvcDiscVersionMismatch vm;
268 0 : vm.max_supported = svc_protocol::MaxVersion;
269 0 : sendMsg(sock, vm);
270 0 : continue;
271 0 : }
272 1213 : auto resp = SvcDiscoveryChannelHandler::buildResponse(*acc, peerAccountUri);
273 4852 : JAMI_LOG("[SvcDiscovery] returning {} service(s) to peer={}", resp.services.size(), peerAccountUri);
274 1213 : if (!sendMsg(sock, resp))
275 0 : JAMI_WARNING("[SvcDiscovery] failed to send SvcDiscResponse to {}", peerAccountUri);
276 2421 : } else if (type == svc_protocol::MsgType::ServiceList) {
277 1208 : svc_protocol::SvcDiscResponse resp;
278 : try {
279 1208 : obj.convert(resp);
280 0 : } catch (const std::exception& e) {
281 0 : JAMI_WARNING("[SvcDiscovery] bad service_list: {}", e.what());
282 0 : continue;
283 0 : }
284 1208 : ResponseCb cb;
285 : {
286 1208 : std::lock_guard lk(state->mtx);
287 1208 : cb = state->responseCb;
288 1208 : }
289 1208 : if (cb)
290 1208 : cb(peerAccountUri, resp.device, resp.services);
291 : else
292 0 : JAMI_WARNING("[SvcDiscovery] no responseCb set; dropping {} services from {}",
293 : resp.services.size(),
294 : peerAccountUri);
295 1208 : } else if (type == svc_protocol::MsgType::ServiceUpdate) {
296 0 : svc_protocol::SvcDiscServiceUpdate update;
297 : try {
298 0 : obj.convert(update);
299 0 : } catch (const std::exception& e) {
300 0 : JAMI_WARNING("[SvcDiscovery] bad service_update: {}", e.what());
301 0 : continue;
302 0 : }
303 0 : ResponseCb cb;
304 : {
305 0 : std::lock_guard lk(state->mtx);
306 0 : cb = state->responseCb;
307 0 : }
308 0 : if (cb)
309 0 : cb(peerAccountUri, update.device, update.services);
310 0 : } else if (type == svc_protocol::MsgType::VersionMismatch || type == svc_protocol::MsgType::Error) {
311 0 : ResponseCb cb;
312 : {
313 0 : std::lock_guard lk(state->mtx);
314 0 : cb = state->responseCb;
315 0 : }
316 0 : if (cb)
317 0 : cb(peerAccountUri, std::string {}, {});
318 0 : } else {
319 0 : JAMI_WARNING("[SvcDiscovery] unknown message type '{}'", type);
320 : }
321 2421 : }
322 2421 : return static_cast<ssize_t>(size);
323 2421 : });
324 2428 : }
325 :
326 : void
327 0 : SvcDiscoveryChannelHandler::broadcastServiceUpdate()
328 : {
329 0 : auto acc = account_.lock();
330 0 : if (!acc)
331 0 : return;
332 :
333 : // Snapshot connected peers and their channels under the lock.
334 0 : std::map<std::string, std::vector<std::shared_ptr<dhtnet::ChannelSocket>>> snapshot;
335 : {
336 0 : std::lock_guard lk(state_->mtx);
337 0 : snapshot = state_->channels;
338 0 : }
339 :
340 0 : if (snapshot.empty())
341 0 : return;
342 :
343 0 : std::size_t peerCount = 0;
344 0 : for (const auto& [peerUri, sockets] : snapshot) {
345 : // Build a per-peer filtered service list.
346 0 : auto resp = buildResponse(*acc, peerUri);
347 0 : svc_protocol::SvcDiscServiceUpdate update;
348 0 : update.device = resp.device;
349 0 : update.services = std::move(resp.services);
350 :
351 0 : for (const auto& sock : sockets)
352 0 : sendMsg(sock, update);
353 0 : ++peerCount;
354 0 : }
355 0 : JAMI_LOG("[SvcDiscovery] broadcast service_update to {} peer(s)", peerCount);
356 0 : }
357 :
358 : void
359 1219 : SvcDiscoveryChannelHandler::refreshDevice(const std::string& /*peerUri*/, const DeviceId& deviceId)
360 : {
361 3657 : connect(deviceId,
362 2438 : std::string(svc_protocol::DiscoveryChannelName),
363 1219 : [](std::shared_ptr<dhtnet::ChannelSocket>, const DeviceId&) {
364 : // Connection callback — nothing needed here; the cache is
365 : // updated when the response is read via installReader/responseCb.
366 1219 : });
367 1219 : }
368 :
369 : std::vector<SvcDiscoveryChannelHandler::CachedSvcInfo>
370 1834 : SvcDiscoveryChannelHandler::getCachedServices(const std::string& peerUri) const
371 : {
372 1834 : std::vector<CachedSvcInfo> result;
373 1834 : std::lock_guard lk(state_->mtx);
374 11411 : for (const auto& [dev, entry] : state_->cache) {
375 9576 : if (entry.peerUri == peerUri) {
376 1414 : for (const auto& svc : entry.services)
377 0 : result.push_back(CachedSvcInfo {dev, svc});
378 : }
379 : }
380 3668 : return result;
381 1834 : }
382 :
383 : void
384 0 : SvcDiscoveryChannelHandler::removeDevice(const DeviceId& deviceId)
385 : {
386 : {
387 0 : std::lock_guard lk(state_->mtx);
388 0 : state_->cache.erase(deviceId);
389 0 : }
390 0 : saveCache();
391 0 : }
392 :
393 : void
394 0 : SvcDiscoveryChannelHandler::saveCache() const
395 : {
396 0 : if (cachePath_.empty())
397 0 : return;
398 0 : std::error_code ec;
399 0 : std::filesystem::create_directories(cachePath_, ec);
400 0 : std::ofstream out(cachePath_ / SVC_CACHE_FILENAME, std::ios::binary | std::ios::trunc);
401 0 : if (out) {
402 0 : std::lock_guard lk(state_->mtx);
403 0 : msgpack::pack(out, state_->cache);
404 0 : } else {
405 0 : JAMI_WARNING("[SvcDiscovery] failed to save cache to disk");
406 : }
407 0 : }
408 :
409 : void
410 672 : SvcDiscoveryChannelHandler::loadCache()
411 : {
412 672 : if (cachePath_.empty())
413 0 : return;
414 672 : auto path = cachePath_ / SVC_CACHE_FILENAME;
415 672 : std::ifstream in(path, std::ios::binary);
416 672 : if (!in)
417 666 : return;
418 6 : std::string content((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
419 : try {
420 6 : auto oh = msgpack::unpack(content.data(), content.size());
421 6 : std::lock_guard lk(state_->mtx);
422 6 : oh.get().convert(state_->cache);
423 6 : } catch (const std::exception& e) {
424 0 : JAMI_WARNING("[SvcDiscovery] failed to load cache: {}", e.what());
425 0 : return;
426 0 : }
427 24 : JAMI_LOG("[SvcDiscovery] loaded {} cached device entries from disk", state_->cache.size());
428 1338 : }
429 :
430 : } // namespace jami
|