From 4b372b60500485fa3bc9cbf9675550c9701464c1 Mon Sep 17 00:00:00 2001 From: Rubidium Date: Fri, 2 Feb 2024 18:13:38 +0100 Subject: [PATCH] Codechange: use std::shared_ptr to manage saveload filters instead of manually trying to avoid double frees --- src/network/network_client.cpp | 19 ++++--------- src/network/network_client.h | 2 +- src/network/network_server.cpp | 2 +- src/network/network_server.h | 2 +- src/openttd.cpp | 2 +- src/saveload/saveload.cpp | 50 ++++++++++++++-------------------- src/saveload/saveload.h | 4 +-- src/saveload/saveload_filter.h | 18 ++++++------ 8 files changed, 40 insertions(+), 59 deletions(-) diff --git a/src/network/network_client.cpp b/src/network/network_client.cpp index 6c9fd88273..670b7db47a 100644 --- a/src/network/network_client.cpp +++ b/src/network/network_client.cpp @@ -155,7 +155,6 @@ ClientNetworkGameSocketHandler::~ClientNetworkGameSocketHandler() assert(ClientNetworkGameSocketHandler::my_client == this); ClientNetworkGameSocketHandler::my_client = nullptr; - delete this->savegame; delete this->GetInfo(); } @@ -567,7 +566,7 @@ bool ClientNetworkGameSocketHandler::IsConnected() * DEF_CLIENT_RECEIVE_COMMAND has parameter: Packet *p ************/ -extern bool SafeLoad(const std::string &filename, SaveLoadOperation fop, DetailedFileType dft, GameMode newgm, Subdirectory subdir, struct LoadFilter *lf = nullptr); +extern bool SafeLoad(const std::string &filename, SaveLoadOperation fop, DetailedFileType dft, GameMode newgm, Subdirectory subdir, std::shared_ptr lf); NetworkRecvStatus ClientNetworkGameSocketHandler::Receive_SERVER_FULL(Packet *) { @@ -810,7 +809,7 @@ NetworkRecvStatus ClientNetworkGameSocketHandler::Receive_SERVER_MAP_BEGIN(Packe if (this->savegame != nullptr) return NETWORK_RECV_STATUS_MALFORMED_PACKET; - this->savegame = new PacketReader(); + this->savegame = std::make_shared(); _frame_counter = _frame_counter_server = _frame_counter_max = p->Recv_uint32(); @@ -864,20 +863,12 @@ NetworkRecvStatus ClientNetworkGameSocketHandler::Receive_SERVER_MAP_DONE(Packet _network_join_status = NETWORK_JOIN_STATUS_PROCESSING; SetWindowDirty(WC_NETWORK_STATUS_WINDOW, WN_NETWORK_STATUS_WINDOW_JOIN); - /* - * Make sure everything is set for reading. - * - * We need the local copy and reset this->savegame because when - * loading fails the network gets reset upon loading the intro - * game, which would cause us to free this->savegame twice. - */ - LoadFilter *lf = this->savegame; - this->savegame = nullptr; - lf->Reset(); + this->savegame->Reset(); /* The map is done downloading, load it */ ClearErrorMessages(); - bool load_success = SafeLoad({}, SLO_LOAD, DFT_GAME_FILE, GM_NORMAL, NO_DIRECTORY, lf); + bool load_success = SafeLoad({}, SLO_LOAD, DFT_GAME_FILE, GM_NORMAL, NO_DIRECTORY, this->savegame); + this->savegame = nullptr; /* Long savegame loads shouldn't affect the lag calculation! */ this->last_packet = std::chrono::steady_clock::now(); diff --git a/src/network/network_client.h b/src/network/network_client.h index 42caac9289..038ed0e551 100644 --- a/src/network/network_client.h +++ b/src/network/network_client.h @@ -16,7 +16,7 @@ class ClientNetworkGameSocketHandler : public ZeroedMemoryAllocator, public NetworkGameSocketHandler { private: std::string connection_string; ///< Address we are connected to. - struct PacketReader *savegame; ///< Packet reader for reading the savegame. + std::shared_ptr savegame; ///< Packet reader for reading the savegame. byte token; ///< The token we need to send back to the server to prove we're the right client. /** Status of the connection with the server. */ diff --git a/src/network/network_server.cpp b/src/network/network_server.cpp index 96c95434e5..1dda7dbcf4 100644 --- a/src/network/network_server.cpp +++ b/src/network/network_server.cpp @@ -576,7 +576,7 @@ NetworkRecvStatus ServerNetworkGameSocketHandler::SendMap() Debug(net, 9, "client[{}] SendMap(): first_packet", this->client_id); WaitTillSaved(); - this->savegame = new PacketWriter(this); + this->savegame = std::make_shared(this); /* Now send the _frame_counter and how many packets are coming */ Packet *p = new Packet(PACKET_SERVER_MAP_BEGIN); diff --git a/src/network/network_server.h b/src/network/network_server.h index 171a89233d..2ff1c9798f 100644 --- a/src/network/network_server.h +++ b/src/network/network_server.h @@ -69,7 +69,7 @@ public: CommandQueue outgoing_queue; ///< The command-queue awaiting delivery size_t receive_limit; ///< Amount of bytes that we can receive at this moment - struct PacketWriter *savegame; ///< Writer used to write the savegame. + std::shared_ptr savegame; ///< Writer used to write the savegame. NetworkAddress client_address; ///< IP-address of the client (so they can be banned) ServerNetworkGameSocketHandler(SOCKET s); diff --git a/src/openttd.cpp b/src/openttd.cpp index 3af31509e1..31db6b0d60 100644 --- a/src/openttd.cpp +++ b/src/openttd.cpp @@ -966,7 +966,7 @@ static void MakeNewEditorWorld() * @param subdir default directory to look for filename, set to 0 if not needed * @param lf Load filter to use, if nullptr: use filename + subdir. */ -bool SafeLoad(const std::string &filename, SaveLoadOperation fop, DetailedFileType dft, GameMode newgm, Subdirectory subdir, struct LoadFilter *lf = nullptr) +bool SafeLoad(const std::string &filename, SaveLoadOperation fop, DetailedFileType dft, GameMode newgm, Subdirectory subdir, std::shared_ptr lf = nullptr) { assert(fop == SLO_LOAD); assert(dft == DFT_GAME_FILE || (lf == nullptr && dft == DFT_OLD_GAME_FILE)); diff --git a/src/saveload/saveload.cpp b/src/saveload/saveload.cpp index c4e078cf39..a25267ddf0 100644 --- a/src/saveload/saveload.cpp +++ b/src/saveload/saveload.cpp @@ -88,14 +88,14 @@ struct ReadBuffer { byte buf[MEMORY_CHUNK_SIZE]; ///< Buffer we're going to read from. byte *bufp; ///< Location we're at reading the buffer. byte *bufe; ///< End of the buffer we can read from. - LoadFilter *reader; ///< The filter used to actually read. + std::shared_ptr reader; ///< The filter used to actually read. size_t read; ///< The amount of read bytes so far from the filter. /** * Initialise our variables. * @param reader The filter to actually read data. */ - ReadBuffer(LoadFilter *reader) : bufp(nullptr), bufe(nullptr), reader(reader), read(0) + ReadBuffer(std::shared_ptr reader) : bufp(nullptr), bufe(nullptr), reader(reader), read(0) { } @@ -162,7 +162,7 @@ struct MemoryDumper { * Flush this dumper into a writer. * @param writer The filter we want to use. */ - void Flush(SaveFilter *writer) + void Flush(std::shared_ptr writer) { uint i = 0; size_t t = this->GetSize(); @@ -199,10 +199,10 @@ struct SaveLoadParams { bool expect_table_header; ///< In the case of a table, if the header is saved/loaded. MemoryDumper *dumper; ///< Memory dumper to write the savegame to. - SaveFilter *sf; ///< Filter to write the savegame to. + std::shared_ptr sf; ///< Filter to write the savegame to. ReadBuffer *reader; ///< Savegame reading buffer. - LoadFilter *lf; ///< Filter to read the savegame from. + std::shared_ptr lf; ///< Filter to read the savegame from. StringID error_str; ///< the translatable error message to show std::string extra_msg; ///< the error message @@ -2175,9 +2175,6 @@ struct FileReader : LoadFilter { { if (this->file != nullptr) fclose(this->file); this->file = nullptr; - - /* Make sure we don't double free. */ - _sl.sf = nullptr; } size_t Read(byte *buf, size_t size) override @@ -2213,9 +2210,6 @@ struct FileWriter : SaveFilter { ~FileWriter() { this->Finish(); - - /* Make sure we don't double free. */ - _sl.sf = nullptr; } void Write(byte *buf, size_t size) override @@ -2249,7 +2243,7 @@ struct LZOLoadFilter : LoadFilter { * Initialise this filter. * @param chain The next filter in this chain. */ - LZOLoadFilter(LoadFilter *chain) : LoadFilter(chain) + LZOLoadFilter(std::shared_ptr chain) : LoadFilter(chain) { if (lzo_init() != LZO_E_OK) SlError(STR_GAME_SAVELOAD_ERROR_BROKEN_INTERNAL_ERROR, "cannot initialize decompressor"); } @@ -2296,7 +2290,7 @@ struct LZOSaveFilter : SaveFilter { * Initialise this filter. * @param chain The next filter in this chain. */ - LZOSaveFilter(SaveFilter *chain, byte) : SaveFilter(chain) + LZOSaveFilter(std::shared_ptr chain, byte) : SaveFilter(chain) { if (lzo_init() != LZO_E_OK) SlError(STR_GAME_SAVELOAD_ERROR_BROKEN_INTERNAL_ERROR, "cannot initialize compressor"); } @@ -2336,7 +2330,7 @@ struct NoCompLoadFilter : LoadFilter { * Initialise this filter. * @param chain The next filter in this chain. */ - NoCompLoadFilter(LoadFilter *chain) : LoadFilter(chain) + NoCompLoadFilter(std::shared_ptr chain) : LoadFilter(chain) { } @@ -2352,7 +2346,7 @@ struct NoCompSaveFilter : SaveFilter { * Initialise this filter. * @param chain The next filter in this chain. */ - NoCompSaveFilter(SaveFilter *chain, byte) : SaveFilter(chain) + NoCompSaveFilter(std::shared_ptr chain, byte) : SaveFilter(chain) { } @@ -2378,7 +2372,7 @@ struct ZlibLoadFilter : LoadFilter { * Initialise this filter. * @param chain The next filter in this chain. */ - ZlibLoadFilter(LoadFilter *chain) : LoadFilter(chain) + ZlibLoadFilter(std::shared_ptr chain) : LoadFilter(chain) { memset(&this->z, 0, sizeof(this->z)); if (inflateInit(&this->z) != Z_OK) SlError(STR_GAME_SAVELOAD_ERROR_BROKEN_INTERNAL_ERROR, "cannot initialize decompressor"); @@ -2423,7 +2417,7 @@ struct ZlibSaveFilter : SaveFilter { * @param chain The next filter in this chain. * @param compression_level The requested level of compression. */ - ZlibSaveFilter(SaveFilter *chain, byte compression_level) : SaveFilter(chain) + ZlibSaveFilter(std::shared_ptr chain, byte compression_level) : SaveFilter(chain) { memset(&this->z, 0, sizeof(this->z)); if (deflateInit(&this->z, compression_level) != Z_OK) SlError(STR_GAME_SAVELOAD_ERROR_BROKEN_INTERNAL_ERROR, "cannot initialize compressor"); @@ -2507,7 +2501,7 @@ struct LZMALoadFilter : LoadFilter { * Initialise this filter. * @param chain The next filter in this chain. */ - LZMALoadFilter(LoadFilter *chain) : LoadFilter(chain), lzma(_lzma_init) + LZMALoadFilter(std::shared_ptr chain) : LoadFilter(chain), lzma(_lzma_init) { /* Allow saves up to 256 MB uncompressed */ if (lzma_auto_decoder(&this->lzma, 1 << 28, 0) != LZMA_OK) SlError(STR_GAME_SAVELOAD_ERROR_BROKEN_INTERNAL_ERROR, "cannot initialize decompressor"); @@ -2551,7 +2545,7 @@ struct LZMASaveFilter : SaveFilter { * @param chain The next filter in this chain. * @param compression_level The requested level of compression. */ - LZMASaveFilter(SaveFilter *chain, byte compression_level) : SaveFilter(chain), lzma(_lzma_init) + LZMASaveFilter(std::shared_ptr chain, byte compression_level) : SaveFilter(chain), lzma(_lzma_init) { if (lzma_easy_encoder(&this->lzma, compression_level, LZMA_CHECK_CRC32) != LZMA_OK) SlError(STR_GAME_SAVELOAD_ERROR_BROKEN_INTERNAL_ERROR, "cannot initialize compressor"); } @@ -2611,8 +2605,8 @@ struct SaveLoadFormat { const char *name; ///< name of the compressor/decompressor (debug-only) uint32_t tag; ///< the 4-letter tag by which it is identified in the savegame - LoadFilter *(*init_load)(LoadFilter *chain); ///< Constructor for the load filter. - SaveFilter *(*init_write)(SaveFilter *chain, byte compression); ///< Constructor for the save filter. + std::shared_ptr (*init_load)(std::shared_ptr chain); ///< Constructor for the load filter. + std::shared_ptr (*init_write)(std::shared_ptr chain, byte compression); ///< Constructor for the save filter. byte min_compression; ///< the minimum compression level of this format byte default_compression; ///< the default compression level of this format @@ -2720,13 +2714,11 @@ static inline void ClearSaveLoadState() delete _sl.dumper; _sl.dumper = nullptr; - delete _sl.sf; _sl.sf = nullptr; delete _sl.reader; _sl.reader = nullptr; - delete _sl.lf; _sl.lf = nullptr; } @@ -2839,7 +2831,7 @@ void WaitTillSaved() * @param threaded Whether to try to perform the saving asynchronously. * @return Return the result of the action. #SL_OK or #SL_ERROR */ -static SaveOrLoadResult DoSave(SaveFilter *writer, bool threaded) +static SaveOrLoadResult DoSave(std::shared_ptr writer, bool threaded) { assert(!_sl.saveinprogress); @@ -2871,7 +2863,7 @@ static SaveOrLoadResult DoSave(SaveFilter *writer, bool threaded) * @param threaded Whether to try to perform the saving asynchronously. * @return Return the result of the action. #SL_OK or #SL_ERROR */ -SaveOrLoadResult SaveWithFilter(SaveFilter *writer, bool threaded) +SaveOrLoadResult SaveWithFilter(std::shared_ptr writer, bool threaded) { try { _sl.action = SLA_SAVE; @@ -2888,7 +2880,7 @@ SaveOrLoadResult SaveWithFilter(SaveFilter *writer, bool threaded) * @param load_check Whether to perform the checking ("preview") or actually load the game. * @return Return the result of the action. #SL_OK or #SL_REINIT ("unload" the game) */ -static SaveOrLoadResult DoLoad(LoadFilter *reader, bool load_check) +static SaveOrLoadResult DoLoad(std::shared_ptr reader, bool load_check) { _sl.lf = reader; @@ -3027,7 +3019,7 @@ static SaveOrLoadResult DoLoad(LoadFilter *reader, bool load_check) * @param reader The filter to read the savegame from. * @return Return the result of the action. #SL_OK or #SL_REINIT ("unload" the game) */ -SaveOrLoadResult LoadWithFilter(LoadFilter *reader) +SaveOrLoadResult LoadWithFilter(std::shared_ptr reader) { try { _sl.action = SLA_LOAD; @@ -3114,13 +3106,13 @@ SaveOrLoadResult SaveOrLoad(const std::string &filename, SaveLoadOperation fop, Debug(desync, 1, "save: {:08x}; {:02x}; {}", TimerGameEconomy::date, TimerGameEconomy::date_fract, filename); if (!_settings_client.gui.threaded_saves) threaded = false; - return DoSave(new FileWriter(fh), threaded); + return DoSave(std::make_shared(fh), threaded); } /* LOAD game */ assert(fop == SLO_LOAD || fop == SLO_CHECK); Debug(desync, 1, "load: {}", filename); - return DoLoad(new FileReader(fh), fop == SLO_CHECK); + return DoLoad(std::make_shared(fh), fop == SLO_CHECK); } catch (...) { /* This code may be executed both for old and new save games. */ ClearSaveLoadState(); diff --git a/src/saveload/saveload.h b/src/saveload/saveload.h index 9cdecc7cd7..1dfe3f96c1 100644 --- a/src/saveload/saveload.h +++ b/src/saveload/saveload.h @@ -421,8 +421,8 @@ void DoExitSave(); void DoAutoOrNetsave(FiosNumberedSaveName &counter); -SaveOrLoadResult SaveWithFilter(struct SaveFilter *writer, bool threaded); -SaveOrLoadResult LoadWithFilter(struct LoadFilter *reader); +SaveOrLoadResult SaveWithFilter(std::shared_ptr writer, bool threaded); +SaveOrLoadResult LoadWithFilter(std::shared_ptr reader); typedef void AutolengthProc(void *arg); diff --git a/src/saveload/saveload_filter.h b/src/saveload/saveload_filter.h index 490daec872..445208f041 100644 --- a/src/saveload/saveload_filter.h +++ b/src/saveload/saveload_filter.h @@ -13,20 +13,19 @@ /** Interface for filtering a savegame till it is loaded. */ struct LoadFilter { /** Chained to the (savegame) filters. */ - LoadFilter *chain; + std::shared_ptr chain; /** * Initialise this filter. * @param chain The next filter in this chain. */ - LoadFilter(LoadFilter *chain) : chain(chain) + LoadFilter(std::shared_ptr chain) : chain(chain) { } /** Make sure the writers are properly closed. */ virtual ~LoadFilter() { - delete this->chain; } /** @@ -51,28 +50,27 @@ struct LoadFilter { * @param chain The next filter in this chain. * @tparam T The type of load filter to create. */ -template LoadFilter *CreateLoadFilter(LoadFilter *chain) +template std::shared_ptr CreateLoadFilter(std::shared_ptr chain) { - return new T(chain); + return std::make_shared(chain); } /** Interface for filtering a savegame till it is written. */ struct SaveFilter { /** Chained to the (savegame) filters. */ - SaveFilter *chain; + std::shared_ptr chain; /** * Initialise this filter. * @param chain The next filter in this chain. */ - SaveFilter(SaveFilter *chain) : chain(chain) + SaveFilter(std::shared_ptr chain) : chain(chain) { } /** Make sure the writers are properly closed. */ virtual ~SaveFilter() { - delete this->chain; } /** @@ -97,9 +95,9 @@ struct SaveFilter { * @param compression_level The requested level of compression. * @tparam T The type of save filter to create. */ -template SaveFilter *CreateSaveFilter(SaveFilter *chain, byte compression_level) +template std::shared_ptr CreateSaveFilter(std::shared_ptr chain, byte compression_level) { - return new T(chain, compression_level); + return std::make_shared(chain, compression_level); } #endif /* SAVELOAD_FILTER_H */