diff --git a/src/network/core/os_abstraction.h b/src/network/core/os_abstraction.h index f4d090bbd6..0b4abc0fcf 100644 --- a/src/network/core/os_abstraction.h +++ b/src/network/core/os_abstraction.h @@ -141,4 +141,22 @@ NetworkError GetSocketError(SOCKET d); static_assert(sizeof(in_addr) == 4); ///< IPv4 addresses should be 4 bytes. static_assert(sizeof(in6_addr) == 16); ///< IPv6 addresses should be 16 bytes. +struct SocketSender { + SOCKET sock; + + ssize_t operator()(std::span buffer) + { + return send(this->sock, reinterpret_cast(buffer.data()), static_cast(buffer.size()), 0); + } +}; + +struct SocketReceiver { + SOCKET sock; + + ssize_t operator()(std::span buffer) + { + return recv(this->sock, reinterpret_cast(buffer.data()), static_cast(buffer.size()), 0); + } +}; + #endif /* NETWORK_CORE_OS_ABSTRACTION_H */ diff --git a/src/network/core/packet.cpp b/src/network/core/packet.cpp index ef43de7ed1..69058c130b 100644 --- a/src/network/core/packet.cpp +++ b/src/network/core/packet.cpp @@ -397,18 +397,18 @@ std::vector Packet::Recv_buffer() /** * Extract at most the length of the span bytes from the packet into the span. - * @param span The span to write the bytes to. + * @param destination The span to write the bytes to. * @return The number of bytes that were actually read. */ -size_t Packet::Recv_bytes(std::span span) +size_t Packet::Recv_bytes(std::span destination) { - auto tranfer_to_span = [](std::span destination, const char *source, size_t amount) { - size_t to_copy = std::min(amount, destination.size()); - std::copy(source, source + to_copy, destination.data()); - return to_copy; + auto transfer_to_span = [&destination](std::span source) { + auto to_copy = source.subspan(0, destination.size()); + std::ranges::copy(to_copy, destination.begin()); + return to_copy.size(); }; - return this->TransferOut(tranfer_to_span, span); + return this->TransferOut(transfer_to_span); } /** diff --git a/src/network/core/packet.h b/src/network/core/packet.h index d1ffa33c80..0219f18fde 100644 --- a/src/network/core/packet.h +++ b/src/network/core/packet.h @@ -95,29 +95,22 @@ public: * Transfer data from the packet to the given function. It starts reading at the * position the last transfer stopped. * See Packet::TransferIn for more information about transferring data to functions. - * @param transfer_function The function to pass the buffer as second parameter and the - * amount to write as third parameter. It returns the amount that - * was written or -1 upon errors. + * @param transfer_function The function to pass span of bytes to write to. + * It returns the amount that was written or -1 upon errors. * @param limit The maximum amount of bytes to transfer. - * @param destination The first parameter of the transfer function. - * @param args The fourth and further parameters to the transfer function, if any. + * @tparam F The type of the transfer_function. * @return The return value of the transfer_function. */ - template < - typename A = size_t, ///< The type for the amount to be passed, so it can be cast to the right type. - typename F, ///< The type of the function. - typename D, ///< The type of the destination. - typename ... Args> ///< The types of the remaining arguments to the function. - ssize_t TransferOutWithLimit(F transfer_function, size_t limit, D destination, Args&& ... args) + template + ssize_t TransferOutWithLimit(F transfer_function, size_t limit) { size_t amount = std::min(this->RemainingBytesToTransfer(), limit); if (amount == 0) return 0; assert(this->pos < this->buffer.size()); assert(this->pos + amount <= this->buffer.size()); - /* Making buffer a char means casting a lot in the Recv/Send functions. */ - const char *output_buffer = reinterpret_cast(this->buffer.data() + this->pos); - ssize_t bytes = transfer_function(destination, output_buffer, static_cast(amount), std::forward(args)...); + auto output_buffer = std::span(this->buffer.data() + this->pos, amount); + ssize_t bytes = transfer_function(output_buffer); if (bytes > 0) this->pos += bytes; return bytes; } @@ -126,21 +119,15 @@ public: * Transfer data from the packet to the given function. It starts reading at the * position the last transfer stopped. * See Packet::TransferIn for more information about transferring data to functions. - * @param transfer_function The function to pass the buffer as second parameter and the - * amount to write as third parameter. It returns the amount that - * was written or -1 upon errors. - * @param destination The first parameter of the transfer function. - * @param args The fourth and further parameters to the transfer function, if any. - * @tparam A The type for the amount to be passed, so it can be cast to the right type. - * @tparam F The type of the transfer_function. - * @tparam D The type of the destination. - * @tparam Args The types of the remaining arguments to the function. + * @param transfer_function The function to pass span of bytes to write to. + * It returns the amount that was written or -1 upon errors. + * @tparam F The type of the transfer_function. * @return The return value of the transfer_function. */ - template - ssize_t TransferOut(F transfer_function, D destination, Args&& ... args) + template + ssize_t TransferOut(F transfer_function) { - return TransferOutWithLimit(transfer_function, std::numeric_limits::max(), destination, std::forward(args)...); + return TransferOutWithLimit(transfer_function, std::numeric_limits::max()); } /** @@ -161,28 +148,21 @@ public: * * This will attempt to write all the remaining bytes into the packet. It updates the * position based on how many bytes were actually written by the called transfer_function. - * @param transfer_function The function to pass the buffer as second parameter and the - * amount to read as third parameter. It returns the amount that - * was read or -1 upon errors. - * @param source The first parameter of the transfer function. - * @param args The fourth and further parameters to the transfer function, if any. - * @tparam A The type for the amount to be passed, so it can be cast to the right type. - * @tparam F The type of the transfer_function. - * @tparam S The type of the source. - * @tparam Args The types of the remaining arguments to the function. + * @param transfer_function The function to pass a span of bytes to read to. + * It returns the amount that was read or -1 upon errors. + * @tparam F The type of the transfer_function. * @return The return value of the transfer_function. */ - template - ssize_t TransferIn(F transfer_function, S source, Args&& ... args) + template + ssize_t TransferIn(F transfer_function) { size_t amount = this->RemainingBytesToTransfer(); if (amount == 0) return 0; assert(this->pos < this->buffer.size()); assert(this->pos + amount <= this->buffer.size()); - /* Making buffer a char means casting a lot in the Recv/Send functions. */ - char *input_buffer = reinterpret_cast(this->buffer.data() + this->pos); - ssize_t bytes = transfer_function(source, input_buffer, static_cast(amount), std::forward(args)...); + auto input_buffer = std::span(this->buffer.data() + this->pos, amount); + ssize_t bytes = transfer_function(input_buffer); if (bytes > 0) this->pos += bytes; return bytes; } diff --git a/src/network/core/tcp.cpp b/src/network/core/tcp.cpp index a03eed0b74..d29a530564 100644 --- a/src/network/core/tcp.cpp +++ b/src/network/core/tcp.cpp @@ -81,7 +81,7 @@ SendPacketsState NetworkTCPSocketHandler::SendPackets(bool closing_down) while (!this->packet_queue.empty()) { Packet &p = *this->packet_queue.front(); - ssize_t res = p.TransferOut(send, this->sock, 0); + ssize_t res = p.TransferOut(SocketSender{this->sock}); if (res == -1) { NetworkError err = NetworkError::GetLast(); if (!err.WouldBlock()) { @@ -131,7 +131,7 @@ std::unique_ptr NetworkTCPSocketHandler::ReceivePacket() /* Read packet size */ if (!p.HasPacketSizeData()) { while (p.RemainingBytesToTransfer() != 0) { - res = p.TransferIn(recv, this->sock, 0); + res = p.TransferIn(SocketReceiver{this->sock}); if (res == -1) { NetworkError err = NetworkError::GetLast(); if (!err.WouldBlock()) { @@ -159,7 +159,7 @@ std::unique_ptr NetworkTCPSocketHandler::ReceivePacket() /* Read rest of packet */ while (p.RemainingBytesToTransfer() != 0) { - res = p.TransferIn(recv, this->sock, 0); + res = p.TransferIn(SocketReceiver{this->sock}); if (res == -1) { NetworkError err = NetworkError::GetLast(); if (!err.WouldBlock()) { diff --git a/src/network/core/tcp_listen.h b/src/network/core/tcp_listen.h index b68e25b825..b2a7b85ee4 100644 --- a/src/network/core/tcp_listen.h +++ b/src/network/core/tcp_listen.h @@ -43,7 +43,7 @@ public: Debug(net, 2, "[{}] Banned ip tried to join ({}), refused", Tsocket::GetName(), entry); - if (p.TransferOut(send, s, 0) < 0) { + if (p.TransferOut(SocketSender{s}) < 0) { Debug(net, 0, "[{}] send failed: {}", Tsocket::GetName(), NetworkError::GetLast().AsString()); } closesocket(s); @@ -58,7 +58,7 @@ public: Packet p(nullptr, Tfull_packet); p.PrepareToSend(); - if (p.TransferOut(send, s, 0) < 0) { + if (p.TransferOut(SocketSender{s}) < 0) { Debug(net, 0, "[{}] send failed: {}", Tsocket::GetName(), NetworkError::GetLast().AsString()); } closesocket(s); diff --git a/src/network/core/udp.cpp b/src/network/core/udp.cpp index d3e69fce85..45900d1118 100644 --- a/src/network/core/udp.cpp +++ b/src/network/core/udp.cpp @@ -94,7 +94,9 @@ void NetworkUDPSocketHandler::SendPacket(Packet &p, NetworkAddress &recv, bool a } /* Send the buffer */ - ssize_t res = p.TransferOut(sendto, s.first, 0, (const struct sockaddr *)send.GetAddress(), send.GetAddressLength()); + ssize_t res = p.TransferOut([&](std::span buffer) { + return sendto(s.first, reinterpret_cast(buffer.data()), static_cast(buffer.size()), 0, reinterpret_cast(send.GetAddress()), send.GetAddressLength()); + }); Debug(net, 7, "sendto({})", send.GetAddressAsString()); /* Check for any errors, but ignore it otherwise */ @@ -120,7 +122,9 @@ void NetworkUDPSocketHandler::ReceivePackets() /* Try to receive anything */ SetNonBlocking(s.first); // Some OSes seem to lose the non-blocking status of the socket - ssize_t nbytes = p.TransferIn(recvfrom, s.first, 0, (struct sockaddr *)&client_addr, &client_len); + ssize_t nbytes = p.TransferIn([&](std::span buffer) { + return recvfrom(s.first, reinterpret_cast(buffer.data()), static_cast(buffer.size()), 0, reinterpret_cast(&client_addr), &client_len); + }); /* Did we get the bytes for the base header of the packet? */ if (nbytes <= 0) break; // No data, i.e. no packet diff --git a/src/network/network_client.cpp b/src/network/network_client.cpp index bb210fe3f6..38a5cc26cb 100644 --- a/src/network/network_client.cpp +++ b/src/network/network_client.cpp @@ -52,19 +52,6 @@ struct PacketReader : LoadFilter { { } - /** - * Simple wrapper around fwrite to be able to pass it to Packet's TransferOut. - * @param destination The reader to add the data to. - * @param source The buffer to read data from. - * @param amount The number of bytes to copy. - * @return The number of bytes that were copied. - */ - static inline ssize_t TransferOutMemCopy(PacketReader *destination, const char *source, size_t amount) - { - std::copy_n(source, amount, std::back_inserter(destination->buffer)); - return amount; - } - /** * Add a packet to this buffer. * @param p The packet to add. @@ -72,7 +59,10 @@ struct PacketReader : LoadFilter { void AddPacket(Packet &p) { assert(this->read_bytes == 0); - p.TransferOut(TransferOutMemCopy, this); + p.TransferOut([this](std::span source) { + std::ranges::copy(source, std::back_inserter(this->buffer)); + return source.size(); + }); } size_t Read(uint8_t *rbuf, size_t size) override diff --git a/src/network/network_content.cpp b/src/network/network_content.cpp index f577e53a71..8728d931a2 100644 --- a/src/network/network_content.cpp +++ b/src/network/network_content.cpp @@ -446,18 +446,6 @@ static bool GunzipFile(const ContentInfo &ci) #endif /* defined(WITH_ZLIB) */ } -/** - * Simple wrapper around fwrite to be able to pass it to Packet's TransferOut. - * @param file The file to write data to. - * @param buffer The buffer to write to the file. - * @param amount The number of bytes to write. - * @return The number of bytes that were written. - */ -static inline ssize_t TransferOutFWrite(std::optional &file, const char *buffer, size_t amount) -{ - return fwrite(buffer, 1, amount, *file); -} - bool ClientNetworkContentSocketHandler::Receive_SERVER_CONTENT(Packet &p) { if (!this->cur_file.has_value()) { @@ -474,8 +462,11 @@ bool ClientNetworkContentSocketHandler::Receive_SERVER_CONTENT(Packet &p) } } else { /* We have a file opened, thus are downloading internal content */ - size_t to_read = p.RemainingBytesToTransfer(); - if (to_read != 0 && static_cast(p.TransferOut(TransferOutFWrite, std::ref(this->cur_file))) != to_read) { + ssize_t to_read = p.RemainingBytesToTransfer(); + auto write_to_disk = [this](std::span buffer) { + return fwrite(buffer.data(), 1, buffer.size(), *this->cur_file); + }; + if (to_read != 0 && p.TransferOut(write_to_disk) != to_read) { CloseWindowById(WC_NETWORK_STATUS_WINDOW, WN_NETWORK_STATUS_WINDOW_CONTENT_DOWNLOAD); ShowErrorMessage( GetEncodedString(STR_CONTENT_ERROR_COULD_NOT_DOWNLOAD), diff --git a/src/tests/test_network_crypto.cpp b/src/tests/test_network_crypto.cpp index a1acb4e1db..c3d0184780 100644 --- a/src/tests/test_network_crypto.cpp +++ b/src/tests/test_network_crypto.cpp @@ -39,14 +39,14 @@ static std::tuple CreatePacketForReading(Packet &source, MockNetwo Packet dest(socket_handler, COMPAT_MTU, source.Size()); - auto transfer_in = [](Packet &source, char *dest_data, size_t length) { - auto transfer_out = [](char *dest_data, const char *source_data, size_t length) { - std::copy(source_data, source_data + length, dest_data); - return length; + auto transfer_in = [&source](std::span dest_data) { + auto transfer_out = [&dest_data](std::span source_data) { + std::ranges::copy(source_data, dest_data.begin()); + return source_data.size(); }; - return source.TransferOutWithLimit(transfer_out, length, dest_data); + return source.TransferOutWithLimit(transfer_out, dest_data.size()); }; - dest.TransferIn(transfer_in, source); + dest.TransferIn(transfer_in); bool valid = dest.PrepareToRead(); dest.Recv_uint8(); // Ignore the type