Enhance HookVerifier to check marshaled types

This commit is contained in:
wolfcomp 2026-01-28 15:43:59 +01:00 committed by Blair
parent f01971a7d7
commit fd85a8d3bc

View file

@ -67,6 +67,7 @@ internal static class HookVerifier
} }
var passedType = typeof(T); var passedType = typeof(T);
var isAssemblyMarshaled = passedType.Assembly.GetCustomAttribute<DisableRuntimeMarshallingAttribute>() is not null;
// Directly compare delegates // Directly compare delegates
if (passedType == entry.TargetDelegateType) if (passedType == entry.TargetDelegateType)
@ -78,7 +79,7 @@ internal static class HookVerifier
var enforcedInvoke = entry.TargetDelegateType.GetMethod("Invoke")!; var enforcedInvoke = entry.TargetDelegateType.GetMethod("Invoke")!;
// Compare Return Type // Compare Return Type
var mismatch = !CheckParam(passedInvoke.ReturnType, enforcedInvoke.ReturnType); var mismatch = !CheckParam(passedInvoke.ReturnType, enforcedInvoke.ReturnType, isAssemblyMarshaled);
// Compare Parameter Count // Compare Parameter Count
var passedParams = passedInvoke.GetParameters(); var passedParams = passedInvoke.GetParameters();
@ -93,7 +94,7 @@ internal static class HookVerifier
// Compare Parameter Types // Compare Parameter Types
for (var i = 0; i < passedParams.Length; i++) for (var i = 0; i < passedParams.Length; i++)
{ {
if (!CheckParam(passedParams[i].ParameterType, enforcedParams[i].ParameterType)) if (!CheckParam(passedParams[i].ParameterType, enforcedParams[i].ParameterType, isAssemblyMarshaled))
{ {
mismatch = true; mismatch = true;
break; break;
@ -107,18 +108,18 @@ internal static class HookVerifier
} }
} }
private static bool CheckParam(Type paramLeft, Type paramRight) private static bool CheckParam(Type paramLeft, Type paramRight, bool isMarshaled)
{ {
var sameType = paramLeft == paramRight; var sameType = paramLeft == paramRight;
return sameType || SizeOf(paramLeft) == SizeOf(paramRight); return sameType || SizeOf(paramLeft, isMarshaled) == SizeOf(paramRight, false);
} }
private static int SizeOf(Type type) private static int SizeOf(Type type, bool isMarshaled)
{ {
return type switch { return type switch {
_ when type == typeof(sbyte) || type == typeof(byte) || type == typeof(bool) => 1, _ when type == typeof(sbyte) || type == typeof(byte) || (type == typeof(bool) && !isMarshaled) => 1,
_ when type == typeof(char) || type == typeof(short) || type == typeof(ushort) || type == typeof(Half) => 2, _ when type == typeof(char) || type == typeof(short) || type == typeof(ushort) || type == typeof(Half) => 2,
_ when type == typeof(int) || type == typeof(uint) || type == typeof(float) => 4, _ when type == typeof(int) || type == typeof(uint) || type == typeof(float) || (type == typeof(bool) && isMarshaled) => 4,
_ when type == typeof(long) || type == typeof(ulong) || type == typeof(double) || type.IsPointer || type.IsFunctionPointer || type.IsUnmanagedFunctionPointer || (type.Name == "Pointer`1" && type.Namespace.AsSpan().SequenceEqual(ClientStructsInteropNamespacePrefix)) || type == typeof(CStringPointer) => 8, _ when type == typeof(long) || type == typeof(ulong) || type == typeof(double) || type.IsPointer || type.IsFunctionPointer || type.IsUnmanagedFunctionPointer || (type.Name == "Pointer`1" && type.Namespace.AsSpan().SequenceEqual(ClientStructsInteropNamespacePrefix)) || type == typeof(CStringPointer) => 8,
_ when type.Name.StartsWith("FixedSizeArray") => SizeOf(type.GetGenericArguments()[0]) * int.Parse(type.Name[14..type.Name.IndexOf('`')]), _ when type.Name.StartsWith("FixedSizeArray") => SizeOf(type.GetGenericArguments()[0]) * int.Parse(type.Name[14..type.Name.IndexOf('`')]),
_ when type.GetCustomAttribute<InlineArrayAttribute>() is { Length: var length } => SizeOf(type.GetGenericArguments()[0]) * length, _ when type.GetCustomAttribute<InlineArrayAttribute>() is { Length: var length } => SizeOf(type.GetGenericArguments()[0]) * length,