diff --git a/src/newgrf.cpp b/src/newgrf.cpp index 3dc4fca19f..ee440aa464 100644 --- a/src/newgrf.cpp +++ b/src/newgrf.cpp @@ -1218,23 +1218,26 @@ struct InvokeGrfActionHandler { * XXX: We consider GRF files trusted. It would be trivial to exploit OTTD by * a crafted invalid GRF file. We should tell that to the user somehow, or * better make this more robust in the future. */ -static void DecodeSpecialSprite(uint8_t *buf, uint num, GrfLoadingStage stage) +static void DecodeSpecialSprite(ReusableBuffer &allocator, uint num, GrfLoadingStage stage) { + uint8_t *buf; auto it = _grf_line_to_action6_sprite_override.find({_cur_gps.grfconfig->ident.grfid, _cur_gps.nfo_line}); if (it == _grf_line_to_action6_sprite_override.end()) { /* No preloaded sprite to work with; read the * pseudo sprite content. */ + buf = allocator.Allocate(num); _cur_gps.file->ReadBlock(buf, num); } else { /* Use the preloaded sprite data. */ buf = it->second.data(); + assert(it->second.size() == num); GrfMsg(7, "DecodeSpecialSprite: Using preloaded pseudo sprite data"); /* Skip the real (original) content of this action. */ _cur_gps.file->SeekTo(num, SEEK_CUR); } - ByteReader br(buf, buf + num); + ByteReader br(buf, num); try { uint8_t action = br.ReadByte(); @@ -1302,7 +1305,7 @@ static void LoadNewGRFFileFromFile(GRFConfig &config, GrfLoadingStage stage, Spr _cur_gps.ClearDataForNextFile(); - ReusableBuffer buf; + ReusableBuffer allocator; while ((num = (grf_container_version >= 2 ? file.ReadDword() : file.ReadWord())) != 0) { uint8_t type = file.ReadByte(); @@ -1317,7 +1320,7 @@ static void LoadNewGRFFileFromFile(GRFConfig &config, GrfLoadingStage stage, Spr break; } - DecodeSpecialSprite(buf.Allocate(num), num, stage); + DecodeSpecialSprite(allocator, num, stage); /* Stop all processing if we are to skip the remaining sprites */ if (_cur_gps.skip_sprites == -1) break; diff --git a/src/newgrf/newgrf_bytereader.cpp b/src/newgrf/newgrf_bytereader.cpp index d09c1dfadc..a1a9585ff2 100644 --- a/src/newgrf/newgrf_bytereader.cpp +++ b/src/newgrf/newgrf_bytereader.cpp @@ -14,17 +14,6 @@ #include "../safeguards.h" -/** - * Read a single DWord (32 bits). - * @note The buffer is NOT advanced. - * @returns Value read from buffer. - */ -uint32_t ByteReader::PeekDWord() -{ - AutoRestoreBackup backup(this->data, this->data); - return this->ReadDWord(); -} - /** * Read a value of the given number of bytes. * @returns Value read from buffer. @@ -40,18 +29,3 @@ uint32_t ByteReader::ReadVarSize(uint8_t size) return 0; } } - -/** - * Read a string. - * @returns Sting read from the buffer. - */ -std::string_view ByteReader::ReadString() -{ - const char *string = reinterpret_cast(this->data); - size_t string_length = ttd_strnlen(string, this->Remaining()); - - /* Skip past the terminating NUL byte if it is present, but not more than remaining. */ - this->Skip(std::min(string_length + 1, this->Remaining())); - - return std::string_view(string, string_length); -} diff --git a/src/newgrf/newgrf_bytereader.h b/src/newgrf/newgrf_bytereader.h index b3ceec7f2b..e298524741 100644 --- a/src/newgrf/newgrf_bytereader.h +++ b/src/newgrf/newgrf_bytereader.h @@ -10,24 +10,21 @@ #ifndef NEWGRF_BYTEREADER_H #define NEWGRF_BYTEREADER_H +#include "../core/string_consumer.hpp" + class OTTDByteReaderSignal { }; /** Class to read from a NewGRF file */ class ByteReader { + StringConsumer consumer; public: - ByteReader(const uint8_t *data, const uint8_t *end) : data(data), end(end) { } + ByteReader(const uint8_t *data, size_t len) : consumer(reinterpret_cast(data), len) { } const uint8_t *ReadBytes(size_t size) { - if (this->data + size >= this->end) { - /* Put data at the end, as would happen if every byte had been individually read. */ - this->data = this->end; - throw OTTDByteReaderSignal(); - } - - const uint8_t *ret = this->data; - this->data += size; - return ret; + auto result = this->consumer.Read(size); + if (result.size() != size) throw OTTDByteReaderSignal(); + return reinterpret_cast(result.data()); } /** @@ -36,8 +33,9 @@ public: */ uint8_t ReadByte() { - if (this->data < this->end) return *this->data++; - throw OTTDByteReaderSignal(); + auto result = this->consumer.TryReadUint8(); + if (!result.has_value()) throw OTTDByteReaderSignal(); + return *result; } /** @@ -46,8 +44,9 @@ public: */ uint16_t ReadWord() { - uint16_t val = this->ReadByte(); - return val | (this->ReadByte() << 8); + auto result = this->consumer.TryReadUint16LE(); + if (!result.has_value()) throw OTTDByteReaderSignal(); + return *result; } /** @@ -66,35 +65,51 @@ public: */ uint32_t ReadDWord() { - uint32_t val = this->ReadWord(); - return val | (this->ReadWord() << 16); + auto result = this->consumer.TryReadUint32LE(); + if (!result.has_value()) throw OTTDByteReaderSignal(); + return *result; + } + + /** + * Read a single DWord (32 bits). + * @note The buffer is NOT advanced. + * @returns Value read from buffer. + */ + uint32_t PeekDWord() + { + auto result = this->consumer.PeekUint32LE(); + if (!result.has_value()) throw OTTDByteReaderSignal(); + return *result; } - uint32_t PeekDWord(); uint32_t ReadVarSize(uint8_t size); - std::string_view ReadString(); + + /** + * Read a NUL-terminated string. + * @returns String read from the buffer. + */ + std::string_view ReadString() + { + /* Terminating NUL may be missing at the end of sprite. */ + return this->consumer.ReadUntilChar('\0', StringConsumer::SKIP_ONE_SEPARATOR); + } size_t Remaining() const { - return this->end - this->data; + return this->consumer.GetBytesLeft(); } bool HasData(size_t count = 1) const { - return this->data + count <= this->end; + return count <= this->consumer.GetBytesLeft(); } void Skip(size_t len) { - this->data += len; - /* It is valid to move the buffer to exactly the end of the data, - * as there may not be any more data read. */ - if (this->data > this->end) throw OTTDByteReaderSignal(); + auto result = this->consumer.Read(len); + if (result.size() != len) throw OTTDByteReaderSignal(); } -private: - const uint8_t *data; ///< Current position within data. - const uint8_t *end; ///< Last position of data. }; #endif /* NEWGRF_BYTEREADER_H */ diff --git a/src/string_func.h b/src/string_func.h index 0952993855..09c8dc18db 100644 --- a/src/string_func.h +++ b/src/string_func.h @@ -70,20 +70,6 @@ inline bool StrEmpty(const char *s) return s == nullptr || s[0] == '\0'; } -/** - * Get the length of a string, within a limited buffer. - * - * @param str The pointer to the first element of the buffer - * @param maxlen The maximum size of the buffer - * @return The length of the string - */ -inline size_t ttd_strnlen(const char *str, size_t maxlen) -{ - const char *t; - for (t = str; static_cast(t - str) < maxlen && *t != '\0'; t++) {} - return t - str; -} - bool IsValidChar(char32_t key, CharSetFilter afilter); size_t Utf8Decode(char32_t *c, const char *s);