1
0
Fork 0

Codechange: use std::span for transferring data in network code

pull/14144/head
Rubidium 2025-04-27 20:08:45 +02:00 committed by rubidium42
parent b7e7f08f78
commit c6ea0ce961
9 changed files with 71 additions and 88 deletions

View File

@ -141,4 +141,22 @@ NetworkError GetSocketError(SOCKET d);
static_assert(sizeof(in_addr) == 4); ///< IPv4 addresses should be 4 bytes.
static_assert(sizeof(in6_addr) == 16); ///< IPv6 addresses should be 16 bytes.
struct SocketSender {
SOCKET sock;
ssize_t operator()(std::span<const uint8_t> buffer)
{
return send(this->sock, reinterpret_cast<const char *>(buffer.data()), static_cast<int>(buffer.size()), 0);
}
};
struct SocketReceiver {
SOCKET sock;
ssize_t operator()(std::span<uint8_t> buffer)
{
return recv(this->sock, reinterpret_cast<char *>(buffer.data()), static_cast<int>(buffer.size()), 0);
}
};
#endif /* NETWORK_CORE_OS_ABSTRACTION_H */

View File

@ -397,18 +397,18 @@ std::vector<uint8_t> Packet::Recv_buffer()
/**
* Extract at most the length of the span bytes from the packet into the span.
* @param span The span to write the bytes to.
* @param destination The span to write the bytes to.
* @return The number of bytes that were actually read.
*/
size_t Packet::Recv_bytes(std::span<uint8_t> span)
size_t Packet::Recv_bytes(std::span<uint8_t> destination)
{
auto tranfer_to_span = [](std::span<uint8_t> destination, const char *source, size_t amount) {
size_t to_copy = std::min(amount, destination.size());
std::copy(source, source + to_copy, destination.data());
return to_copy;
auto transfer_to_span = [&destination](std::span<const uint8_t> source) {
auto to_copy = source.subspan(0, destination.size());
std::ranges::copy(to_copy, destination.begin());
return to_copy.size();
};
return this->TransferOut(tranfer_to_span, span);
return this->TransferOut(transfer_to_span);
}
/**

View File

@ -95,29 +95,22 @@ public:
* Transfer data from the packet to the given function. It starts reading at the
* position the last transfer stopped.
* See Packet::TransferIn for more information about transferring data to functions.
* @param transfer_function The function to pass the buffer as second parameter and the
* amount to write as third parameter. It returns the amount that
* was written or -1 upon errors.
* @param transfer_function The function to pass span of bytes to write to.
* It returns the amount that was written or -1 upon errors.
* @param limit The maximum amount of bytes to transfer.
* @param destination The first parameter of the transfer function.
* @param args The fourth and further parameters to the transfer function, if any.
* @tparam F The type of the transfer_function.
* @return The return value of the transfer_function.
*/
template <
typename A = size_t, ///< The type for the amount to be passed, so it can be cast to the right type.
typename F, ///< The type of the function.
typename D, ///< The type of the destination.
typename ... Args> ///< The types of the remaining arguments to the function.
ssize_t TransferOutWithLimit(F transfer_function, size_t limit, D destination, Args&& ... args)
template <typename F>
ssize_t TransferOutWithLimit(F transfer_function, size_t limit)
{
size_t amount = std::min(this->RemainingBytesToTransfer(), limit);
if (amount == 0) return 0;
assert(this->pos < this->buffer.size());
assert(this->pos + amount <= this->buffer.size());
/* Making buffer a char means casting a lot in the Recv/Send functions. */
const char *output_buffer = reinterpret_cast<const char*>(this->buffer.data() + this->pos);
ssize_t bytes = transfer_function(destination, output_buffer, static_cast<A>(amount), std::forward<Args>(args)...);
auto output_buffer = std::span<const uint8_t>(this->buffer.data() + this->pos, amount);
ssize_t bytes = transfer_function(output_buffer);
if (bytes > 0) this->pos += bytes;
return bytes;
}
@ -126,21 +119,15 @@ public:
* Transfer data from the packet to the given function. It starts reading at the
* position the last transfer stopped.
* See Packet::TransferIn for more information about transferring data to functions.
* @param transfer_function The function to pass the buffer as second parameter and the
* amount to write as third parameter. It returns the amount that
* was written or -1 upon errors.
* @param destination The first parameter of the transfer function.
* @param args The fourth and further parameters to the transfer function, if any.
* @tparam A The type for the amount to be passed, so it can be cast to the right type.
* @param transfer_function The function to pass span of bytes to write to.
* It returns the amount that was written or -1 upon errors.
* @tparam F The type of the transfer_function.
* @tparam D The type of the destination.
* @tparam Args The types of the remaining arguments to the function.
* @return The return value of the transfer_function.
*/
template <typename A = size_t, typename F, typename D, typename ... Args>
ssize_t TransferOut(F transfer_function, D destination, Args&& ... args)
template <typename F>
ssize_t TransferOut(F transfer_function)
{
return TransferOutWithLimit<A>(transfer_function, std::numeric_limits<size_t>::max(), destination, std::forward<Args>(args)...);
return TransferOutWithLimit(transfer_function, std::numeric_limits<size_t>::max());
}
/**
@ -161,28 +148,21 @@ public:
*
* This will attempt to write all the remaining bytes into the packet. It updates the
* position based on how many bytes were actually written by the called transfer_function.
* @param transfer_function The function to pass the buffer as second parameter and the
* amount to read as third parameter. It returns the amount that
* was read or -1 upon errors.
* @param source The first parameter of the transfer function.
* @param args The fourth and further parameters to the transfer function, if any.
* @tparam A The type for the amount to be passed, so it can be cast to the right type.
* @param transfer_function The function to pass a span of bytes to read to.
* It returns the amount that was read or -1 upon errors.
* @tparam F The type of the transfer_function.
* @tparam S The type of the source.
* @tparam Args The types of the remaining arguments to the function.
* @return The return value of the transfer_function.
*/
template <typename A = size_t, typename F, typename S, typename ... Args>
ssize_t TransferIn(F transfer_function, S source, Args&& ... args)
template <typename F>
ssize_t TransferIn(F transfer_function)
{
size_t amount = this->RemainingBytesToTransfer();
if (amount == 0) return 0;
assert(this->pos < this->buffer.size());
assert(this->pos + amount <= this->buffer.size());
/* Making buffer a char means casting a lot in the Recv/Send functions. */
char *input_buffer = reinterpret_cast<char*>(this->buffer.data() + this->pos);
ssize_t bytes = transfer_function(source, input_buffer, static_cast<A>(amount), std::forward<Args>(args)...);
auto input_buffer = std::span<uint8_t>(this->buffer.data() + this->pos, amount);
ssize_t bytes = transfer_function(input_buffer);
if (bytes > 0) this->pos += bytes;
return bytes;
}

View File

@ -81,7 +81,7 @@ SendPacketsState NetworkTCPSocketHandler::SendPackets(bool closing_down)
while (!this->packet_queue.empty()) {
Packet &p = *this->packet_queue.front();
ssize_t res = p.TransferOut<int>(send, this->sock, 0);
ssize_t res = p.TransferOut(SocketSender{this->sock});
if (res == -1) {
NetworkError err = NetworkError::GetLast();
if (!err.WouldBlock()) {
@ -131,7 +131,7 @@ std::unique_ptr<Packet> NetworkTCPSocketHandler::ReceivePacket()
/* Read packet size */
if (!p.HasPacketSizeData()) {
while (p.RemainingBytesToTransfer() != 0) {
res = p.TransferIn<int>(recv, this->sock, 0);
res = p.TransferIn(SocketReceiver{this->sock});
if (res == -1) {
NetworkError err = NetworkError::GetLast();
if (!err.WouldBlock()) {
@ -159,7 +159,7 @@ std::unique_ptr<Packet> NetworkTCPSocketHandler::ReceivePacket()
/* Read rest of packet */
while (p.RemainingBytesToTransfer() != 0) {
res = p.TransferIn<int>(recv, this->sock, 0);
res = p.TransferIn(SocketReceiver{this->sock});
if (res == -1) {
NetworkError err = NetworkError::GetLast();
if (!err.WouldBlock()) {

View File

@ -43,7 +43,7 @@ public:
Debug(net, 2, "[{}] Banned ip tried to join ({}), refused", Tsocket::GetName(), entry);
if (p.TransferOut<int>(send, s, 0) < 0) {
if (p.TransferOut(SocketSender{s}) < 0) {
Debug(net, 0, "[{}] send failed: {}", Tsocket::GetName(), NetworkError::GetLast().AsString());
}
closesocket(s);
@ -58,7 +58,7 @@ public:
Packet p(nullptr, Tfull_packet);
p.PrepareToSend();
if (p.TransferOut<int>(send, s, 0) < 0) {
if (p.TransferOut(SocketSender{s}) < 0) {
Debug(net, 0, "[{}] send failed: {}", Tsocket::GetName(), NetworkError::GetLast().AsString());
}
closesocket(s);

View File

@ -94,7 +94,9 @@ void NetworkUDPSocketHandler::SendPacket(Packet &p, NetworkAddress &recv, bool a
}
/* Send the buffer */
ssize_t res = p.TransferOut<int>(sendto, s.first, 0, (const struct sockaddr *)send.GetAddress(), send.GetAddressLength());
ssize_t res = p.TransferOut([&](std::span<const uint8_t> buffer) {
return sendto(s.first, reinterpret_cast<const char *>(buffer.data()), static_cast<int>(buffer.size()), 0, reinterpret_cast<const struct sockaddr *>(send.GetAddress()), send.GetAddressLength());
});
Debug(net, 7, "sendto({})", send.GetAddressAsString());
/* Check for any errors, but ignore it otherwise */
@ -120,7 +122,9 @@ void NetworkUDPSocketHandler::ReceivePackets()
/* Try to receive anything */
SetNonBlocking(s.first); // Some OSes seem to lose the non-blocking status of the socket
ssize_t nbytes = p.TransferIn<int>(recvfrom, s.first, 0, (struct sockaddr *)&client_addr, &client_len);
ssize_t nbytes = p.TransferIn([&](std::span<uint8_t> buffer) {
return recvfrom(s.first, reinterpret_cast<char *>(buffer.data()), static_cast<int>(buffer.size()), 0, reinterpret_cast<struct sockaddr *>(&client_addr), &client_len);
});
/* Did we get the bytes for the base header of the packet? */
if (nbytes <= 0) break; // No data, i.e. no packet

View File

@ -52,19 +52,6 @@ struct PacketReader : LoadFilter {
{
}
/**
* Simple wrapper around fwrite to be able to pass it to Packet's TransferOut.
* @param destination The reader to add the data to.
* @param source The buffer to read data from.
* @param amount The number of bytes to copy.
* @return The number of bytes that were copied.
*/
static inline ssize_t TransferOutMemCopy(PacketReader *destination, const char *source, size_t amount)
{
std::copy_n(source, amount, std::back_inserter(destination->buffer));
return amount;
}
/**
* Add a packet to this buffer.
* @param p The packet to add.
@ -72,7 +59,10 @@ struct PacketReader : LoadFilter {
void AddPacket(Packet &p)
{
assert(this->read_bytes == 0);
p.TransferOut(TransferOutMemCopy, this);
p.TransferOut([this](std::span<const uint8_t> source) {
std::ranges::copy(source, std::back_inserter(this->buffer));
return source.size();
});
}
size_t Read(uint8_t *rbuf, size_t size) override

View File

@ -446,18 +446,6 @@ static bool GunzipFile(const ContentInfo &ci)
#endif /* defined(WITH_ZLIB) */
}
/**
* Simple wrapper around fwrite to be able to pass it to Packet's TransferOut.
* @param file The file to write data to.
* @param buffer The buffer to write to the file.
* @param amount The number of bytes to write.
* @return The number of bytes that were written.
*/
static inline ssize_t TransferOutFWrite(std::optional<FileHandle> &file, const char *buffer, size_t amount)
{
return fwrite(buffer, 1, amount, *file);
}
bool ClientNetworkContentSocketHandler::Receive_SERVER_CONTENT(Packet &p)
{
if (!this->cur_file.has_value()) {
@ -474,8 +462,11 @@ bool ClientNetworkContentSocketHandler::Receive_SERVER_CONTENT(Packet &p)
}
} else {
/* We have a file opened, thus are downloading internal content */
size_t to_read = p.RemainingBytesToTransfer();
if (to_read != 0 && static_cast<size_t>(p.TransferOut(TransferOutFWrite, std::ref(this->cur_file))) != to_read) {
ssize_t to_read = p.RemainingBytesToTransfer();
auto write_to_disk = [this](std::span<const uint8_t> buffer) {
return fwrite(buffer.data(), 1, buffer.size(), *this->cur_file);
};
if (to_read != 0 && p.TransferOut(write_to_disk) != to_read) {
CloseWindowById(WC_NETWORK_STATUS_WINDOW, WN_NETWORK_STATUS_WINDOW_CONTENT_DOWNLOAD);
ShowErrorMessage(
GetEncodedString(STR_CONTENT_ERROR_COULD_NOT_DOWNLOAD),

View File

@ -39,14 +39,14 @@ static std::tuple<Packet, bool> CreatePacketForReading(Packet &source, MockNetwo
Packet dest(socket_handler, COMPAT_MTU, source.Size());
auto transfer_in = [](Packet &source, char *dest_data, size_t length) {
auto transfer_out = [](char *dest_data, const char *source_data, size_t length) {
std::copy(source_data, source_data + length, dest_data);
return length;
auto transfer_in = [&source](std::span<uint8_t> dest_data) {
auto transfer_out = [&dest_data](std::span<const uint8_t> source_data) {
std::ranges::copy(source_data, dest_data.begin());
return source_data.size();
};
return source.TransferOutWithLimit(transfer_out, length, dest_data);
return source.TransferOutWithLimit(transfer_out, dest_data.size());
};
dest.TransferIn(transfer_in, source);
dest.TransferIn(transfer_in);
bool valid = dest.PrepareToRead();
dest.Recv_uint8(); // Ignore the type