From b2c57123199d96179d105aec3ee14e5ac4a602ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Guilloux?= Date: Thu, 13 Feb 2025 13:46:21 +0100 Subject: [PATCH] Fix: [Script] Improve type checking of parameters (#13522) --- cmake/scripts/SquirrelExport.cmake | 13 ++++++++----- src/script/squirrel.cpp | 16 ++++++++++++++++ src/script/squirrel.hpp | 2 +- src/script/squirrel_helper.hpp | 29 +++++++++++++++++++++-------- 4 files changed, 46 insertions(+), 14 deletions(-) diff --git a/cmake/scripts/SquirrelExport.cmake b/cmake/scripts/SquirrelExport.cmake index 1f7bde1656..507ae7d04c 100644 --- a/cmake/scripts/SquirrelExport.cmake +++ b/cmake/scripts/SquirrelExport.cmake @@ -28,10 +28,10 @@ endmacro() macro(dump_class_templates NAME) string(REGEX REPLACE "^Script" "" REALNAME ${NAME}) - string(APPEND SQUIRREL_EXPORT "\n template <> struct Param<${NAME} *> { static inline ${NAME} *Get(HSQUIRRELVM vm, int index) { SQUserPointer instance; sq_getinstanceup(vm, index, &instance, nullptr); return (${NAME} *)instance; } };") - string(APPEND SQUIRREL_EXPORT "\n template <> struct Param<${NAME} &> { static inline ${NAME} &Get(HSQUIRRELVM vm, int index) { SQUserPointer instance; sq_getinstanceup(vm, index, &instance, nullptr); return *(${NAME} *)instance; } };") - string(APPEND SQUIRREL_EXPORT "\n template <> struct Param { static inline const ${NAME} *Get(HSQUIRRELVM vm, int index) { SQUserPointer instance; sq_getinstanceup(vm, index, &instance, nullptr); return (${NAME} *)instance; } };") - string(APPEND SQUIRREL_EXPORT "\n template <> struct Param { static inline const ${NAME} &Get(HSQUIRRELVM vm, int index) { SQUserPointer instance; sq_getinstanceup(vm, index, &instance, nullptr); return *(${NAME} *)instance; } };") + string(APPEND SQUIRREL_EXPORT "\n template <> struct Param<${NAME} *> { static inline ${NAME} *Get(HSQUIRRELVM vm, int index) { return static_cast<${NAME} *>(Squirrel::GetRealInstance(vm, index, \"${REALNAME}\")); } };") + string(APPEND SQUIRREL_EXPORT "\n template <> struct Param<${NAME} &> { static inline ${NAME} &Get(HSQUIRRELVM vm, int index) { return *static_cast<${NAME} *>(Squirrel::GetRealInstance(vm, index, \"${REALNAME}\")); } };") + string(APPEND SQUIRREL_EXPORT "\n template <> struct Param { static inline const ${NAME} *Get(HSQUIRRELVM vm, int index) { return static_cast<${NAME} *>(Squirrel::GetRealInstance(vm, index, \"${REALNAME}\")); } };") + string(APPEND SQUIRREL_EXPORT "\n template <> struct Param { static inline const ${NAME} &Get(HSQUIRRELVM vm, int index) { return *static_cast<${NAME} *>(Squirrel::GetRealInstance(vm, index, \"${REALNAME}\")); } };") if("${NAME}" STREQUAL "ScriptEvent") string(APPEND SQUIRREL_EXPORT "\n template <> struct Return<${NAME} *> { static inline int Set(HSQUIRRELVM vm, ${NAME} *res) { if (res == nullptr) { sq_pushnull(vm); return 1; } Squirrel::CreateClassInstanceVM(vm, \"${REALNAME}\", res, nullptr, DefSQDestructorCallback<${NAME}>, true); return 1; } };") elseif("${NAME}" STREQUAL "ScriptText") @@ -44,7 +44,10 @@ macro(dump_class_templates NAME) string(APPEND SQUIRREL_EXPORT "\n if (sq_gettype(vm, index) == OT_STRING) {") string(APPEND SQUIRREL_EXPORT "\n return new RawText(Param::Get(vm, index));") string(APPEND SQUIRREL_EXPORT "\n }") - string(APPEND SQUIRREL_EXPORT "\n return nullptr;") + string(APPEND SQUIRREL_EXPORT "\n if (sq_gettype(vm, index) == OT_NULL) {") + string(APPEND SQUIRREL_EXPORT "\n return nullptr;") + string(APPEND SQUIRREL_EXPORT "\n }") + string(APPEND SQUIRREL_EXPORT "\n throw sq_throwerror(vm, fmt::format(\"parameter {} has an invalid type ; expected: 'Text'\", index - 1));") string(APPEND SQUIRREL_EXPORT "\n }") string(APPEND SQUIRREL_EXPORT "\n };") else() diff --git a/src/script/squirrel.cpp b/src/script/squirrel.cpp index df0cf44e51..c6c3ff6bb8 100644 --- a/src/script/squirrel.cpp +++ b/src/script/squirrel.cpp @@ -510,6 +510,22 @@ bool Squirrel::CreateClassInstance(const std::string &class_name, void *real_ins return Squirrel::CreateClassInstanceVM(this->vm, class_name, real_instance, instance, nullptr); } +/* static */ SQUserPointer Squirrel::GetRealInstance(HSQUIRRELVM vm, int index, const char *tag) +{ + Squirrel *engine = static_cast(sq_getforeignptr(vm)); + std::string class_name = fmt::format("{}{}", engine->GetAPIName(), tag); + sq_pushroottable(vm); + sq_pushstring(vm, class_name); + sq_get(vm, -2); + sq_push(vm, index); + if (sq_instanceof(vm)) { + sq_pop(vm, 3); + SQUserPointer ptr = nullptr; + if (SQ_SUCCEEDED(sq_getinstanceup(vm, index, &ptr, nullptr))) return ptr; + } + throw sq_throwerror(vm, fmt::format("parameter {} has an invalid type ; expected: '{}'", index - 1, class_name)); +} + Squirrel::Squirrel(const char *APIName) : APIName(APIName), allocator(new ScriptAllocator()) { diff --git a/src/script/squirrel.hpp b/src/script/squirrel.hpp index 712f40897c..0868182bd2 100644 --- a/src/script/squirrel.hpp +++ b/src/script/squirrel.hpp @@ -194,7 +194,7 @@ public: * @note This will only work just after a function-call from within Squirrel * to your C++ function. */ - static bool GetRealInstance(HSQUIRRELVM vm, SQUserPointer *ptr) { return SQ_SUCCEEDED(sq_getinstanceup(vm, 1, ptr, nullptr)); } + static SQUserPointer GetRealInstance(HSQUIRRELVM vm, int index, const char *tag); /** * Get the Squirrel-instance pointer. diff --git a/src/script/squirrel_helper.hpp b/src/script/squirrel_helper.hpp index 8870e4c9fc..3f93d8983a 100644 --- a/src/script/squirrel_helper.hpp +++ b/src/script/squirrel_helper.hpp @@ -39,7 +39,6 @@ namespace SQConvert { template <> struct Return { static inline int Set(HSQUIRRELVM vm, bool res) { sq_pushbool (vm, res); return 1; } }; template <> struct Return { /* Do not use char *, use std::optional instead. */ }; template <> struct Return { /* Do not use const char *, use std::optional instead. */ }; - template <> struct Return { static inline int Set(HSQUIRRELVM vm, void *res) { sq_pushuserpointer(vm, res); return 1; } }; template <> struct Return { static inline int Set(HSQUIRRELVM vm, HSQOBJECT res) { sq_pushobject(vm, res); return 1; } }; template requires std::is_enum_v struct Return { @@ -86,7 +85,6 @@ namespace SQConvert { template <> struct Param { static inline TileIndex Get(HSQUIRRELVM vm, int index) { SQInteger tmp; sq_getinteger (vm, index, &tmp); return TileIndex((uint32_t)(int32_t)tmp); } }; template <> struct Param { static inline bool Get(HSQUIRRELVM vm, int index) { SQBool tmp; sq_getbool (vm, index, &tmp); return tmp != 0; } }; template <> struct Param { /* Do not use const char *, use std::string& instead. */ }; - template <> struct Param { static inline void *Get(HSQUIRRELVM vm, int index) { SQUserPointer tmp; sq_getuserpointer(vm, index, &tmp); return tmp; } }; template requires std::is_enum_v struct Param { static inline T Get(HSQUIRRELVM vm, int index) @@ -259,7 +257,9 @@ namespace SQConvert { try { /* Delegate it to a template that can handle this specific function */ - return HelperT::SQCall((Tcls *)real_instance, *(Tmethod *)ptr, vm); + auto cls_instance = static_cast(real_instance); + auto method = *static_cast(ptr); + return HelperT::SQCall(cls_instance, method, vm); } catch (SQInteger &e) { return e; } @@ -299,8 +299,14 @@ namespace SQConvert { /* Remove the userdata from the stack */ sq_pop(vm, 1); - /* Call the function, which its only param is always the VM */ - return (SQInteger)(((Tcls *)real_instance)->*(*(Tmethod *)ptr))(vm); + try { + /* Call the function, which its only param is always the VM */ + auto cls_instance = static_cast(real_instance); + auto method = *static_cast(ptr); + return static_cast((cls_instance->*method)(vm)); + } catch (SQInteger &e) { + return e; + } } /** @@ -320,7 +326,9 @@ namespace SQConvert { try { /* Delegate it to a template that can handle this specific function */ - return HelperT::SQCall((Tcls *)nullptr, *(Tmethod *)ptr, vm); + auto cls_instance = static_cast(nullptr); + auto method = *static_cast(ptr); + return HelperT::SQCall(cls_instance, method, vm); } catch (SQInteger &e) { return e; } @@ -344,8 +352,13 @@ namespace SQConvert { /* Remove the userdata from the stack */ sq_pop(vm, 1); - /* Call the function, which its only param is always the VM */ - return (SQInteger)(*(*(Tmethod *)ptr))(vm); + try { + /* Call the function, which its only param is always the VM */ + auto method = *static_cast(ptr); + return static_cast((*method)(vm)); + } catch (SQInteger &e) { + return e; + } } /**