From e49bb21d76d0f0aec8b164b3a58a3e30c6bf5014 Mon Sep 17 00:00:00 2001 From: SamuXarick <43006711+SamuXarick@users.noreply.github.com> Date: Tue, 19 Nov 2024 21:14:55 +0000 Subject: [PATCH] Codechange: Refactor industry management with KD-Tree integration and enhanced data structures - Introduced `IndustryTypeCountCaches` struct to represent the cache of industry counts for a specific industry type. This cache includes a KD-Tree for spatial indexing and a new `IndustryTownCache` struct to represent the cache of industries for a specific town. - Modified `Industry::counts` to use `IndustryTypeCountCaches` struct. - Added `industry_kdtree.h` and `industry_kdtree.cpp` for KD-Tree operations and to avoid circular dependencies. - Introduced `RebuildIndustryKdtree` function to rebuild the KD-Tree for each industry type by iterating over all industries and categorizing them by type. - Updated `IncIndustryTypeCount` and `DecIndustryTypeCount` methods to manage KD-Tree insertions and removals directly within the `Industry` class. - Updated `CountTownIndustriesOfTypeMatchingCondition` to use the new data structures for performance. - Updated `ResetIndustryCounts` to clear the vector of `IndustryTownCache` entries. - Modified `GetIndustryTypeCount` to count industries of a type via KD-Tree for better performance. - Modified `CheckIfFarEnoughFromConflictingIndustry` to utilize KD-Tree for faster conflict checks. - Simplified `DoCreateNewIndustry` and `Industry::~Industry` by directly integrating KD-Tree updates. - Modified `GetClosestIndustry` to use `FindNearest` and `FindNearestExcept` for faster checking. --- src/CMakeLists.txt | 2 + src/industry.h | 176 ++++++++++++++++++++++++++-------- src/industry_cmd.cpp | 43 ++++----- src/industry_kdtree.cpp | 16 ++++ src/industry_kdtree.h | 25 +++++ src/misc.cpp | 2 + src/newgrf_industries.cpp | 22 +++-- src/saveload/afterload.cpp | 1 + src/saveload/industry_sl.cpp | 2 +- src/saveload/oldloader_sl.cpp | 2 +- 10 files changed, 218 insertions(+), 73 deletions(-) create mode 100644 src/industry_kdtree.cpp create mode 100644 src/industry_kdtree.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5f7847ff8a..64386c0b42 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -221,6 +221,8 @@ add_files( industry_cmd.cpp industry_cmd.h industry_gui.cpp + industry_kdtree.cpp + industry_kdtree.h industry_map.h industry_type.h industrytype.h diff --git a/src/industry.h b/src/industry.h index 5ba98e8f5b..8930058e87 100644 --- a/src/industry.h +++ b/src/industry.h @@ -18,6 +18,7 @@ #include "station_base.h" #include "timer/timer_game_calendar.h" #include "timer/timer_game_economy.h" +#include "industry_kdtree.h" #include "town.h" @@ -244,75 +245,95 @@ struct Industry : IndustryPool::PoolItem<&_industry_pool> { static Industry *GetRandom(); static void PostDestructor(size_t index); + /** + * Struct representing the cache of industries for a specific town. + */ + struct IndustryTownCache { + TownID town_id; ///< The ID of the town. + std::vector industry_ids; ///< A vector of IDs of the industries associated with the town. + }; + + /** + * Struct representing the cache of industry counts for a specific industry type. + */ + struct IndustryTypeCountCaches { + IndustryKdtree kdtree; ///< A k-d tree for spatial indexing of industries. + std::vector towns; ///< A vector of IndustryTownCache entries, each representing a town and its associated industries. + }; + /** * Increment the count of industries of the specified type in a town. * - * @param ind Pointer to the Industry to increment its type count in the town. - * @pre ind != nullptr + * This function updates the k-d tree by inserting the industry's index and adjusts the + * industry counts for the specific industry type in the corresponding town. If the industry + * is not already in the town's list, it is added in the correct sorted position. */ - static inline void IncIndustryTypeCount(const Industry *ind) + inline void IncIndustryTypeCount() { - assert(ind != nullptr); + auto &kdtree = this->counts[this->type].kdtree; + kdtree.Insert(this->index); /* Find the correct position to insert or update the town entry using lower_bound. */ - auto &type_vector = counts[ind->type]; - auto pair_it = std::ranges::lower_bound(type_vector, ind->town->index, {}, &std::pair>::first); + auto &type_vector = this->counts[this->type].towns; + auto pair_it = std::ranges::lower_bound(type_vector, this->town->index, {}, &IndustryTownCache::town_id); - if (pair_it != std::end(type_vector) && pair_it->first == ind->town->index) { + if (pair_it != std::end(type_vector) && pair_it->town_id == this->town->index) { /* Create a reference to the town's industry list. */ - auto &iid_vector = pair_it->second; + auto &iid_vector = pair_it->industry_ids; /* Ensure the town's industry list is not empty. */ assert(!std::empty(iid_vector)); /* Ensure the industry is not already in the town's industry list. */ assert(std::ranges::all_of(iid_vector, [&](auto &iid) { - return iid != ind->index; + return iid != this->index; })); /* Find the correct position to insert the industry ID in sorted order. */ - auto iid_it = std::ranges::lower_bound(iid_vector, ind->index); + auto iid_it = std::ranges::lower_bound(iid_vector, this->index); /* Add the industry ID to the town's industry list in the correct position. */ - iid_vector.insert(iid_it, ind->index); + iid_vector.insert(iid_it, this->index); } else { /* Create a new vector for the industry IDs and add the industry ID. */ std::vector iid_vector; - iid_vector.emplace_back(ind->index); + iid_vector.emplace_back(this->index); - /* Insert the new pair (town index and vector of industry IDs) into the correct position. */ - type_vector.emplace(pair_it, ind->town->index, std::move(iid_vector)); + /* Insert the new entry (town index and vector of industry IDs) into the correct position. */ + type_vector.emplace(pair_it, IndustryTownCache{ this->town->index, std::move(iid_vector) }); } } /** * Decrement the count of industries of the specified type in a town. * - * @param ind Pointer to the Industry to decrement its type count in the town. - * @pre ind != nullptr + * This function updates the k-d tree by removing the industry's index and adjusts the + * industry counts for the specific industry type in the corresponding town. If the town's + * industry list becomes empty after removal, the town entry is removed from the list. */ - static inline void DecIndustryTypeCount(const Industry *ind) + inline void DecIndustryTypeCount() { - assert(ind != nullptr); + auto &kdtree = this->counts[this->type].kdtree; + kdtree.Remove(this->index); /* Find the correct position of the town entry using lower_bound. */ - auto &type_vector = counts[ind->type]; - auto pair_it = std::ranges::lower_bound(type_vector, ind->town->index, {}, &std::pair>::first); + auto &type_vector = this->counts[this->type].towns; + auto pair_it = std::ranges::lower_bound(type_vector, this->town->index, {}, &IndustryTownCache::town_id); /* Ensure the pair was found in the array. */ - assert(pair_it != std::end(type_vector) && pair_it->first == ind->town->index); + assert(pair_it != std::end(type_vector) && pair_it->town_id == this->town->index); /* Create a reference to the town's industry list. */ - auto &iid_vector = pair_it->second; + auto &iid_vector = pair_it->industry_ids; /* Ensure the town's industry list is not empty. */ assert(!std::empty(iid_vector)); /* Find the industry ID within the town's industry list. */ - auto iid_it = std::ranges::lower_bound(iid_vector, ind->index); + auto iid_it = std::ranges::lower_bound(iid_vector, this->index); /* Ensure the industry ID was found in the list. */ - assert(iid_it != std::end(iid_vector) && *iid_it == ind->index); + assert(iid_it != std::end(iid_vector) && *iid_it == this->index); /* Erase the industry ID from the town's industry list. */ iid_vector.erase(iid_it); @@ -371,11 +392,12 @@ public: uint16_t count = 0; if (town != nullptr) { /* Find the correct position of the town entry using lower_bound. */ - auto pair_it = std::ranges::lower_bound(counts[type], town->index, {}, &std::pair>::first); + auto &type_vector = counts[type].towns; + auto pair_it = std::ranges::lower_bound(type_vector, town->index, {}, &IndustryTownCache::town_id); - if (pair_it != std::end(counts[type]) && pair_it->first == town->index) { + if (pair_it != std::end(type_vector) && pair_it->town_id == town->index) { /* Create a reference to the town's industry list. */ - auto &iid_vector = pair_it->second; + auto &iid_vector = pair_it->industry_ids; /* Ensure the town's industry list is not empty. */ assert(!std::empty(iid_vector)); @@ -383,13 +405,13 @@ public: } } else { /* Count industries in all towns. */ - for (auto &pair : counts[type]) { + for (auto &pair : counts[type].towns) { /* Create a reference to the town's industry list. */ - auto &iid_vector = pair.second; + auto &iid_vector = pair.industry_ids; /* Ensure the town's industry list is not empty. */ assert(!std::empty(iid_vector)); - count += ProcessIndustries(iid_vector, Town::Get(pair.first), return_early, func); + count += ProcessIndustries(iid_vector, Town::Get(pair.town_id), return_early, func); if (return_early && count != 0) break; } } @@ -403,9 +425,13 @@ public: * @param type The type of industry to count. * @return uint16_t The count of industries of the specified type. */ - static inline uint16_t GetIndustryTypeCount(IndustryType type) + static uint16_t GetIndustryTypeCount(IndustryType type) { - return CountTownIndustriesOfTypeMatchingCondition(type, nullptr, false, [](const Industry *) { return true; }); + uint16_t count = static_cast(counts[type].kdtree.Count()); + + /* Sanity check. Both the Kdtree and the count of industries in all towns should match. */ + assert(count == CountTownIndustriesOfTypeMatchingCondition(type, nullptr, false, [](const Industry *) { return true; })); + return count; } /** @@ -425,13 +451,84 @@ public: /** * Resets the industry counts for all industry types. * - * Clears the vector of industry counts for each industry type, + * Clears the vector of IndustryTownCache entries for each industry type, * effectively resetting the count of industries per type in all towns. */ static inline void ResetIndustryCounts() { /* Clear the vector for each industry type in the counts array. */ - std::ranges::for_each(counts, [](auto &type) { type.clear(); }); + std::ranges::for_each(counts, [](auto &type) { type.towns.clear(); }); + } + + /** + * Rebuilds the k-d tree for each industry type. + * + * Iterates over all industries, categorizes them by type, and rebuilds the + * k-d tree for each industry type using the collected industry IDs. + */ + static inline void RebuildIndustryKdtree() + { + std::array, NUM_INDUSTRYTYPES> industryids; + for (const Industry *industry : Industry::Iterate()) { + industryids[industry->type].push_back(industry->index); + } + + for (IndustryType type = 0; type < NUM_INDUSTRYTYPES; type++) { + counts[type].kdtree.Build(industryids[type].begin(), industryids[type].end()); + } + } + + /** + * Find all industries within a specified radius of a given tile. + * + * This function calculates the rectangular area around the given tile, constrained by the map boundaries, + * and finds all industries contained within this area using the k-d tree. + * + * @param tile The central tile index. + * @param type The industry type to search for. + * @param radius The search radius around the tile. + * @return std::vector A vector of industry IDs found within the specified area. + */ + static inline std::vector FindContained(TileIndex tile, IndustryType type, int radius) + { + static uint16_t x1, x2, y1, y2; + x1 = (uint16_t)std::max(0, TileX(tile) - radius); + x2 = (uint16_t)std::min(TileX(tile) + radius + 1, Map::SizeX()); + y1 = (uint16_t)std::max(0, TileY(tile) - radius); + y2 = (uint16_t)std::min(TileY(tile) + radius + 1, Map::SizeY()); + + return counts[type].kdtree.FindContained(x1, y1, x2, y2); + } + + /** + * Find the industry nearest to a given tile. + * + * This function uses the k-d tree to find the industry of the specified type + * that is closest to the given tile based on its coordinates. + * + * @param tile The tile index to search from. + * @param type The industry type to search for. + * @return IndustryID The ID of the nearest industry. + */ + static inline IndustryID FindNearest(TileIndex tile, IndustryType type) + { + return counts[type].kdtree.FindNearest(TileX(tile), TileY(tile)); + } + + /** + * Find the industry nearest to a given tile, excluding a specified industry. + * + * This function uses the k-d tree to find the industry of the specified type + * that is closest to the given tile, excluding the industry with the given ID. + * + * @param tile The tile index to search from. + * @param type The industry type to search for. + * @param iid The ID of the industry to exclude from the search. + * @return IndustryID The ID of the nearest industry, excluding the specified one. + */ + static inline IndustryID FindNearestExcept(TileIndex tile, IndustryType type, IndustryID iid) + { + return counts[type].kdtree.FindNearestExcept(TileX(tile), TileY(tile), iid); } inline const std::string &GetCachedName() const @@ -445,11 +542,14 @@ private: protected: /** - * Array containing vectors of industry types. - * Each vector corresponds to a specific IndustryType and - * contains pairs of TownIDs and their associated lists of IndustryIDs. + * Array containing data for each industry type. + * Each element corresponds to a specific IndustryType and contains: + * - An IndustryKdtree for spatial indexing. + * - A vector of IndustryTownCache entries, where each entry holds: + * - A TownID representing a town. + * - A vector of IndustryIDs associated with that town. */ - static std::array>>, NUM_INDUSTRYTYPES> counts; + static std::array counts; }; void ClearAllIndustryCachedNames(); diff --git a/src/industry_cmd.cpp b/src/industry_cmd.cpp index 62b96ad4ff..6d266e4c1e 100644 --- a/src/industry_cmd.cpp +++ b/src/industry_cmd.cpp @@ -10,6 +10,7 @@ #include "stdafx.h" #include "clear_map.h" #include "industry.h" +#include "industry_kdtree.h" #include "station_base.h" #include "landscape.h" #include "viewport_func.h" @@ -63,7 +64,7 @@ void BuildOilRig(TileIndex tile); static uint8_t _industry_sound_ctr; static TileIndex _industry_sound_tile; -std::array>>, NUM_INDUSTRYTYPES> Industry::counts; +std::array Industry::counts; IndustrySpec _industry_specs[NUM_INDUSTRYTYPES]; IndustryTileSpec _industry_tile_specs[NUM_INDUSTRYTILES]; @@ -189,7 +190,7 @@ Industry::~Industry() /* Clear the persistent storage. */ delete this->psa; - DecIndustryTypeCount(this); + this->DecIndustryTypeCount(); DeleteIndustryNews(this->index); CloseWindowById(WC_INDUSTRY_VIEW, this->index); @@ -1693,33 +1694,18 @@ static CommandCost CheckIfFarEnoughFromConflictingIndustry(TileIndex tile, Indus /* On a large map with many industries, it may be faster to check an area. */ static const int dmax = 14; - if (Industry::GetNumItems() > static_cast(dmax * dmax * 2)) { - const Industry *i = nullptr; - TileArea tile_area = TileArea(tile, 1, 1).Expand(dmax); - for (TileIndex atile : tile_area) { - if (GetTileType(atile) == MP_INDUSTRY) { - const Industry *i2 = Industry::GetByTile(atile); - if (i == i2) continue; - i = i2; - if (DistanceMax(tile, i->location.tile) > (uint)dmax) continue; - if (i->type == indspec->conflicting[0] || - i->type == indspec->conflicting[1] || - i->type == indspec->conflicting[2]) { - return_cmd_error(STR_ERROR_INDUSTRY_TOO_CLOSE); - } - } - } - return CommandCost(); - } + for (IndustryType conflicting_type : indspec->conflicting) { + if (conflicting_type >= NUM_INDUSTRYTYPES) continue; - for (const Industry *i : Industry::Iterate()) { - /* Within 14 tiles from another industry is considered close */ - if (DistanceMax(tile, i->location.tile) > 14) continue; + std::vector nearby_industries = Industry::FindContained(tile, conflicting_type, dmax); + + auto is_conflicting = [&tile](IndustryID iid) { + /* Within 14 tiles from another industry is considered close */ + return DistanceMax(tile, Industry::Get(iid)->location.tile) <= dmax; + }; /* check if there are any conflicting industry types around */ - if (i->type == indspec->conflicting[0] || - i->type == indspec->conflicting[1] || - i->type == indspec->conflicting[2]) { + if (std::ranges::any_of(nearby_industries, is_conflicting)) { return_cmd_error(STR_ERROR_INDUSTRY_TOO_CLOSE); } } @@ -1809,7 +1795,7 @@ static void DoCreateNewIndustry(Industry *i, TileIndex tile, IndustryType type, } i->town = t; - Industry::IncIndustryTypeCount(i); + i->IncIndustryTypeCount(); i->owner = OWNER_NONE; uint16_t r = Random(); @@ -2500,6 +2486,9 @@ void GenerateIndustries() PlaceInitialIndustry(it, false); } _industry_builder.Reset(); + + /* Build the industry k-d tree again to make sure it's well balanced */ + Industry::RebuildIndustryKdtree(); } /** diff --git a/src/industry_kdtree.cpp b/src/industry_kdtree.cpp new file mode 100644 index 0000000000..75883ba452 --- /dev/null +++ b/src/industry_kdtree.cpp @@ -0,0 +1,16 @@ +/* + * This file is part of OpenTTD. + * OpenTTD is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 2. + * OpenTTD is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + * See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with OpenTTD. If not, see . + */ + +/** @file industry_kdtree.cpp Implementation of the industry k-d tree functions */ + +#include "industry_kdtree.h" +#include "industry.h" + +uint16_t Kdtree_IndustryXYFunc::operator()(IndustryID iid, int dim) +{ + return (dim == 0) ? TileX(Industry::Get(iid)->location.tile) : TileY(Industry::Get(iid)->location.tile); +} diff --git a/src/industry_kdtree.h b/src/industry_kdtree.h new file mode 100644 index 0000000000..980ec16a92 --- /dev/null +++ b/src/industry_kdtree.h @@ -0,0 +1,25 @@ +/* + * This file is part of OpenTTD. + * OpenTTD is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 2. + * OpenTTD is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + * See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with OpenTTD. If not, see . + */ + +/** @file industry_kdtree.h Declarations for accessing the k-d tree of industries */ + +#ifndef INDUSTRY_KDTREE_H +#define INDUSTRY_KDTREE_H + +#include "core/kdtree.hpp" +#include "industry_type.h" + +/* Forward declaration of Industry struct */ +struct Industry; + +struct Kdtree_IndustryXYFunc { + uint16_t operator()(IndustryID iid, int dim); +}; + +using IndustryKdtree = Kdtree; + +#endif diff --git a/src/misc.cpp b/src/misc.cpp index 72701ff23c..c43ba2b28b 100644 --- a/src/misc.cpp +++ b/src/misc.cpp @@ -31,6 +31,7 @@ #include "station_kdtree.h" #include "town_kdtree.h" #include "viewport_kdtree.h" +#include "industry.h" #include "newgrf_profiling.h" #include "3rdparty/monocypher/monocypher.h" @@ -126,6 +127,7 @@ void InitializeGame(uint size_x, uint size_y, bool reset_date, bool reset_settin RebuildStationKdtree(); RebuildTownKdtree(); RebuildViewportKdtree(); + Industry::RebuildIndustryKdtree(); ResetPersistentNewGRFData(); diff --git a/src/newgrf_industries.cpp b/src/newgrf_industries.cpp index 5bc91e9091..fa7afcd64a 100644 --- a/src/newgrf_industries.cpp +++ b/src/newgrf_industries.cpp @@ -92,13 +92,20 @@ uint32_t GetIndustryIDAtOffset(TileIndex tile, const Industry *i, uint32_t cur_g static uint32_t GetClosestIndustry(TileIndex tile, IndustryType type, const Industry *current) { uint32_t best_dist = UINT32_MAX; - for (const Industry *i : Industry::Iterate()) { - if (i->type != type || i == current) continue; - best_dist = std::min(best_dist, DistanceManhattan(tile, i->location.tile)); + const size_t count = Industry::GetIndustryTypeCount(type); + if (count == 0) return best_dist; + + IndustryID iid; + if (current->index == INVALID_INDUSTRY || type != current->type) { + iid = Industry::FindNearest(tile, type); + } else { + if (count == 1) return best_dist; + iid = Industry::FindNearestExcept(tile, type, current->index); + assert(iid != current->index); } - return best_dist; + return DistanceManhattan(tile, Industry::Get(iid)->location.tile); } /** @@ -279,9 +286,12 @@ static uint32_t GetCountAndDistanceOfClosestInstance(uint8_t param_setID, uint8_ } /* Distance of nearest industry of given type */ - case 0x64: + case 0x64: { if (this->tile == INVALID_TILE) break; - return GetClosestIndustry(this->tile, MapNewGRFIndustryType(parameter, indspec->grf_prop.grffile->grfid), this->industry); + IndustryType type = MapNewGRFIndustryType(parameter, indspec->grf_prop.grffile->grfid); + if (type >= NUM_INDUSTRYTYPES) return UINT32_MAX; + return GetClosestIndustry(this->tile, type, this->industry); + } /* Get town zone and Manhattan distance of closest town */ case 0x65: { if (this->tile == INVALID_TILE) break; diff --git a/src/saveload/afterload.cpp b/src/saveload/afterload.cpp index 5df9556bba..f6c0a22be9 100644 --- a/src/saveload/afterload.cpp +++ b/src/saveload/afterload.cpp @@ -580,6 +580,7 @@ bool AfterLoadGame() /* This needs to be done even before conversion, because some conversions will destroy objects * that otherwise won't exist in the tree. */ RebuildViewportKdtree(); + Industry::RebuildIndustryKdtree(); if (IsSavegameVersionBefore(SLV_98)) _gamelog.GRFAddList(_grfconfig); diff --git a/src/saveload/industry_sl.cpp b/src/saveload/industry_sl.cpp index e57033880a..d5ce6c4d69 100644 --- a/src/saveload/industry_sl.cpp +++ b/src/saveload/industry_sl.cpp @@ -236,7 +236,7 @@ struct INDYChunkHandler : ChunkHandler { { for (Industry *i : Industry::Iterate()) { SlObject(i, _industry_desc); - Industry::IncIndustryTypeCount(i); + i->IncIndustryTypeCount(); } } }; diff --git a/src/saveload/oldloader_sl.cpp b/src/saveload/oldloader_sl.cpp index 76aaa89743..628281731f 100644 --- a/src/saveload/oldloader_sl.cpp +++ b/src/saveload/oldloader_sl.cpp @@ -874,7 +874,7 @@ static bool LoadOldIndustry(LoadgameState *ls, int num) i->random_colour = RemapTTOColour(i->random_colour); } - Industry::IncIndustryTypeCount(i); + i->IncIndustryTypeCount(); } else { delete i; }