diff --git a/API/Controller/Admin/GetUsers.cs b/API/Controller/Admin/GetUsers.cs index cfefcb0d..c68fa5f0 100644 --- a/API/Controller/Admin/GetUsers.cs +++ b/API/Controller/Admin/GetUsers.cs @@ -8,6 +8,7 @@ using OpenShock.Common.Utils; using Z.EntityFramework.Plus; using OpenShock.Common.OpenShockDb; +using OpenShock.Common.Query; namespace OpenShock.API.Controller.Admin; @@ -47,7 +48,11 @@ public async Task GetUsers( query = query.OrderBy(u => u.CreatedAt); } } - catch (ExpressionBuilder.ExpressionException e) + catch (QueryStringTokenizerException e) + { + return Problem(ExpressionError.QueryStringInvalidError(e.Message)); + } + catch (DBExpressionBuilderException e) { return Problem(ExpressionError.ExpressionExceptionError(e.Message)); } diff --git a/Common.Tests/Common.Tests.csproj b/Common.Tests/Common.Tests.csproj index fd05faf6..c3f77ff5 100644 --- a/Common.Tests/Common.Tests.csproj +++ b/Common.Tests/Common.Tests.csproj @@ -5,6 +5,7 @@ + diff --git a/Common.Tests/Geo/Alpha2CountryCodeTests.cs b/Common.Tests/Geo/Alpha2CountryCodeTests.cs index 905bea5e..aa6487be 100644 --- a/Common.Tests/Geo/Alpha2CountryCodeTests.cs +++ b/Common.Tests/Geo/Alpha2CountryCodeTests.cs @@ -1,4 +1,5 @@ using OpenShock.Common.Geo; +using TUnit.Assertions.AssertConditions.Throws; namespace OpenShock.Common.Tests.Geo; @@ -22,15 +23,13 @@ public async Task ValidCode_ShouldParse(string str, char char1, char char2) [Arguments("INVALID")] public async Task InvalidCharCount_ShouldThrow_InvalidLength(string str) { - // Act - var ex = await Assert.ThrowsAsync(() => - { - Alpha2CountryCode c = str; - return Task.CompletedTask; - }); - - // Assert - await Assert.That(ex.Message).IsEqualTo("Country code must be exactly 2 characters long (Parameter 'str')"); + // Act & Assert + await Assert.That(() => + { + Alpha2CountryCode c = str; + }) + .ThrowsExactly() + .WithMessage("Country code must be exactly 2 characters long (Parameter 'str')"); } [Test] @@ -44,15 +43,13 @@ public async Task InvalidCharCount_ShouldThrow_InvalidLength(string str) [Arguments(":D")] public async Task InvalidCharTypes_ShouldThrow(string str) { - // Act - var ex = await Assert.ThrowsAsync(() => - { - Alpha2CountryCode c = str; - return Task.CompletedTask; - }); - - // Assert - await Assert.That(ex.Message).IsEqualTo("Country code must be uppercase ASCII characters only (Parameter 'str')"); + // Act & Assert + await Assert.That(() => + { + Alpha2CountryCode c = str; + }) + .ThrowsExactly() + .WithMessage("Country code must be uppercase ASCII characters only (Parameter 'str')"); } [Test] diff --git a/Common.Tests/Query/DBExpressionBuilderTests.cs b/Common.Tests/Query/DBExpressionBuilderTests.cs new file mode 100644 index 00000000..51126414 --- /dev/null +++ b/Common.Tests/Query/DBExpressionBuilderTests.cs @@ -0,0 +1,199 @@ +using OpenShock.Common.Query; +using TUnit.Assertions.AssertConditions.Throws; +using Bogus; + +namespace OpenShock.Common.Tests.Query; + +public class DBExpressionBuilderTests +{ + public sealed class TestClass + { + public required Guid Id { get; set; } + public required string Name { get; set; } + public required int Age { get; set; } + public required uint Height { get; set; } + public required bool IsActive { get; set; } + public required DateTime CreatedAt { get; set; } + public required TestEnum Status { get; set; } + public required float Score { get; set; } + public required double Precision { get; set; } + } + + public enum TestEnum + { + Pending, + Active, + Inactive + } + + private readonly TestClass[] TestArray; + + public DBExpressionBuilderTests() + { + var faker = new Faker() + .UseSeed(12345) + .RuleFor(t => t.Id, f => Guid.CreateVersion7()) + .RuleFor(t => t.Name, f => f.Name.FullName()) + .RuleFor(t => t.Age, f => f.Random.Int(18, 99)) + .RuleFor(t => t.Height, f => f.Random.UInt()) + .RuleFor(t => t.IsActive, f => f.Random.Bool()) + .RuleFor(t => t.CreatedAt, f => f.Date.Past(10)) + .RuleFor(t => t.Status, f => f.PickRandom()) + .RuleFor(t => t.Score, f => f.Random.Float(0, 100)) + .RuleFor(t => t.Precision, f => f.Random.Double(0, 100)); + + TestArray = faker.Generate(100).ToArray(); + } + + [Test] + public async Task EmptyString_ThrowsException() + { + // Act & Assert + await Assert + .That(() => DBExpressionBuilder.GetFilterExpression("")) + .ThrowsExactly(); + } + + [Test] + public async Task IntegerBounds_ThrowsExceptionOnOverflow() + { + // Act & Assert + await Assert + .That(() => DBExpressionBuilder.GetFilterExpression("age eq 2147483648")) + .ThrowsExactly(); + } + + [Test] + public async Task UnsignedIntegerBounds_ThrowsExceptionOnNegative() + { + // Act & Assert + await Assert + .That(() => DBExpressionBuilder.GetFilterExpression("height eq -1")) + .ThrowsExactly(); + } + + [Test] + public async Task Guid_ExactMatch() + { + // Act + var testGuid = TestArray.First().Id; // Grab a Guid from the test data + var expression = DBExpressionBuilder.GetFilterExpression($"id eq {testGuid}"); + var result = TestArray.AsQueryable().Where(expression).ToArray(); + + // Assert + await Assert.That(result).ContainsOnly(x => x.Id == testGuid); + } + + [Test] + public async Task Integer_GreaterThanOrEquals() + { + // Act + var expression = DBExpressionBuilder.GetFilterExpression("age gte 42"); + var result = TestArray.AsQueryable().Where(expression).ToArray(); + + // Assert + await Assert.That(result).ContainsOnly(x => x.Age >= 42); + } + + [Test] + public async Task Integer_LessThanOrEquals() + { + // Act + var expression = DBExpressionBuilder.GetFilterExpression("age lte 51"); + var result = TestArray.AsQueryable().Where(expression).ToArray(); + + // Assert + await Assert.That(result).ContainsOnly(x => x.Age <= 51); + } + + // TODO: Make enums work + /* + [Test] + public async Task Enum_ChecksValidValues() + { + // Act + var expression = DBExpressionBuilder.GetFilterExpression("status eq Active"); + var result = TestArray.AsQueryable().Where(expression).ToArray(); + + // Assert + await Assert.That(result).HasCount().GreaterThan(0); + } + + [Test] + public async Task Enum_InvalidValue_ThrowsException() + { + // Act & Assert + await Assert + .That(() => DBExpressionBuilder.GetFilterExpression("status eq Invalid")) + .ThrowsExactly(); + } + */ + + [Test] + public async Task Boolean_TrueMatches() + { + // Act + var expression = DBExpressionBuilder.GetFilterExpression("isActive eq true"); + var result = TestArray.AsQueryable().Where(expression).ToArray(); + + // Assert + await Assert.That(result).ContainsOnly(x => x.IsActive == true); + } + + [Test] + public async Task Boolean_FalseMatches() + { + // Act + var expression = DBExpressionBuilder.GetFilterExpression("isActive eq false"); + var result = TestArray.AsQueryable().Where(expression).ToArray(); + + // Assert + await Assert.That(result).ContainsOnly(x => x.IsActive == false); + } + + [Test] + public async Task DateTime_ExactMatch() + { + // Act + var testDate = TestArray[20].CreatedAt; + var expression = DBExpressionBuilder.GetFilterExpression($"createdAt eq {testDate:O}"); + var result = TestArray.AsQueryable().Where(expression).ToArray(); + + // Assert + await Assert.That(result).ContainsOnly(x => x.CreatedAt == testDate); + } + + [Test] + public async Task DateTime_LessThan() + { + // Act + var referenceDate = DateTime.UtcNow.AddMonths(-6); + var expression = DBExpressionBuilder.GetFilterExpression($"createdAt lt {referenceDate:O}"); + var result = TestArray.AsQueryable().Where(expression).ToArray(); + + // Assert + await Assert.That(result).ContainsOnly(x => x.CreatedAt < referenceDate); + } + + [Test] + public async Task Float_GreaterThan() + { + // Act + var expression = DBExpressionBuilder.GetFilterExpression("score gt 50"); + var result = TestArray.AsQueryable().Where(expression).ToArray(); + + // Assert + await Assert.That(result).ContainsOnly(x => x.Score > 50f); + } + + [Test] + public async Task Double_LessThan() + { + // Act + var expression = DBExpressionBuilder.GetFilterExpression("precision lt 50"); + var result = TestArray.AsQueryable().Where(expression).ToArray(); + + // Assert + await Assert.That(result).ContainsOnly(x => x.Precision < 50f); + } +} diff --git a/Common.Tests/Query/QueryStringTokenizerTests.cs b/Common.Tests/Query/QueryStringTokenizerTests.cs new file mode 100644 index 00000000..478d9d3e --- /dev/null +++ b/Common.Tests/Query/QueryStringTokenizerTests.cs @@ -0,0 +1,226 @@ +using OpenShock.Common.Query; +using TUnit.Assertions.AssertConditions.Throws; + +namespace OpenShock.Common.Tests.Query; + +public class QueryStringTokenizerTests +{ + [Test] + public async Task EmptyString_ReturnsEmpty() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens(""); + + // Assert + await Assert.That(result).IsEmpty(); + } + + [Test] + public async Task WhiteSpaceString_ReturnsEmpty() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens(" \r\n\t"); + + // Assert + await Assert.That(result).IsEmpty(); + } + + [Test] + public async Task QuotedNewLine_ReturnsNewLine() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens("'\n'"); + + // Assert + await Assert.That(result).IsEquivalentTo(["\n"]); + } + + [Test] + public async Task SimpleString_ReturnsMatching() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens("testing"); + + // Assert + await Assert.That(result).IsEquivalentTo(["testing"]); + } + + [Test] + public async Task NormalUsage_Succeeds() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens("username == 'morgan freeman' and age >= 35 and email ilike morgan*freeman@*.com"); + + // Assert + await Assert.That(result).IsEquivalentTo(["username", "==", "morgan freeman", "and", "age", ">=", "35", "and", "email", "ilike", "morgan*freeman@*.com"]); + } + + [Test] + public async Task SurroundingWhitespace_Ignored() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens(" hello "); + + // Assert + await Assert.That(result).IsEquivalentTo(["hello"]); + } + + [Test] + public async Task SpaceSeperatedString_ReturnsMatching() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens("testing tokenizer"); + + // Assert + await Assert.That(result).IsEquivalentTo(["testing", "tokenizer"]); + } + + [Test] + public async Task MultiSpaceSeperatedString_ReturnsMatching() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens("testing \r \t \n tokenizer"); + + // Assert + await Assert.That(result).IsEquivalentTo(["testing", "tokenizer"]); + } + + [Test] + public async Task UnmatchedQuote_ThrowsException() + { + // Act & Assert + await Assert + .That(() => QueryStringTokenizer.ParseQueryTokens("'hello world")) + .ThrowsExactly(); + } + + [Test] + public async Task EmptyQuotedString_ParsesAsEmpty() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens("''"); + + // Assert + await Assert.That(result).IsEquivalentTo([string.Empty]); + } + + [Test] + public async Task QuotedString_ReturnsMatching() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens("'testing tokenizer'"); + + // Assert + await Assert.That(result).IsEquivalentTo(["testing tokenizer"]); + } + + [Test] + public async Task MixedQuotedAndUnquotedWords_ParsesCorrectly() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens("this 'is a test' string"); + + // Assert + await Assert.That(result).IsEquivalentTo(["this", "is a test", "string"]); + } + + [Test] + public async Task EscapedQuoteInsideQuotedString_ParsesCorrectly() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens("'This isn\\'t a bug'"); + + // Assert + await Assert.That(result).IsEquivalentTo(["This isn't a bug"]); + } + + [Test] + public async Task EscapeAtEndOfQuotedString_ThrowsException() + { + // Act & Assert + await Assert + .That(() => QueryStringTokenizer.ParseQueryTokens("'hello world\\'")) + .ThrowsExactly(); + } + + [Test] + public async Task DoubleEscapedBackslash_ParsesCorrectly() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens("'This has a backslash: \\\\'"); + + // Assert + await Assert.That(result).IsEquivalentTo(["This has a backslash: \\"]); + } + + [Test] + public async Task QuoteInsideUnquotedString_ThrowsException() + { + // Act & Assert + await Assert + .That(() => QueryStringTokenizer.ParseQueryTokens("This won't work")) + .ThrowsExactly(); + } + + [Test] + public async Task UnquotedEscapeCharacter_ThrowsException() + { + // Act & Assert + await Assert + .That(() => QueryStringTokenizer.ParseQueryTokens("hello \\ world")) + .ThrowsExactly(); + } + + [Test] + public async Task OnlyEscapeCharacter_ThrowsException() + { + // Act & Assert + await Assert + .That(() => QueryStringTokenizer.ParseQueryTokens("\\")) + .ThrowsExactly(); + } + + [Test] + public async Task EmbeddedEscapedNewline_ParsesCorrectly() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens("'hello\\nworld'"); + + // Assert + await Assert.That(result).IsEquivalentTo(["hello\nworld"]); + } + + [Test] + public async Task ConsecutiveQuotedStrings_ParsesSeparately() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens("'hello' 'world'"); + + // Assert + await Assert.That(result).IsEquivalentTo(["hello", "world"]); + } + + [Test] + public async Task EmptyInputWithWhitespace_ReturnsEmpty() + { + // Act + var result = QueryStringTokenizer.ParseQueryTokens(" "); + + // Assert + await Assert.That(result).IsEmpty(); + } + + [Test] + [Arguments("'\\ '")] // Escape followed by space + [Arguments("'hello \\q'")] // Invalid escape character + [Arguments("'\\x'")] // Undefined escape sequence + [Arguments("'test \\u1234'")] // Unicode escape not supported + [Arguments("'hello \\'")] // Dangling backslash at end of quoted string + public async Task InvalidEscapeCharacters_ThrowsException(string invalidString) + { + // Act & Assert + await Assert + .That(() => QueryStringTokenizer.ParseQueryTokens(invalidString)) + .ThrowsExactly(); + } +} \ No newline at end of file diff --git a/Common/Errors/ExpressionError.cs b/Common/Errors/ExpressionError.cs index e9d8b1f5..e5d83b50 100644 --- a/Common/Errors/ExpressionError.cs +++ b/Common/Errors/ExpressionError.cs @@ -5,5 +5,6 @@ namespace OpenShock.Common.Errors; public static class ExpressionError { + public static OpenShockProblem QueryStringInvalidError(string details) => new OpenShockProblem("ExpressionError", "Query string is invalid", HttpStatusCode.BadRequest, details); public static OpenShockProblem ExpressionExceptionError(string details) => new OpenShockProblem("ExpressionError", "An error occured while processing the expression", HttpStatusCode.BadRequest, details); } \ No newline at end of file diff --git a/Common/Extensions/IQueryableExtensions.cs b/Common/Extensions/IQueryableExtensions.cs index 94fd771e..06ea4d96 100644 --- a/Common/Extensions/IQueryableExtensions.cs +++ b/Common/Extensions/IQueryableExtensions.cs @@ -1,5 +1,6 @@ using System.Linq.Expressions; using OpenShock.Common.Utils; +using OpenShock.Common.Query; namespace OpenShock.Common.Extensions; @@ -7,14 +8,9 @@ public static class IQueryableExtensions { public static IQueryable ApplyFilter(this IQueryable query, string filterQuery) where T : class { - var filter = ExpressionBuilder.GetFilterExpression(filterQuery); - - if (filter != null) - { - query = query.Where(filter); - } + if (string.IsNullOrWhiteSpace(filterQuery)) return query; - return query; + return query.Where(DBExpressionBuilder.GetFilterExpression(filterQuery)); } public static IOrderedQueryable ApplyOrderBy(this IQueryable query, string orderbyQuery) where T : class @@ -28,9 +24,7 @@ public static IOrderedQueryable ApplyOrderBy(this IQueryable query, str var entityType = typeof(T); - var memberInfo = ExpressionBuilder.GetPropertyOrField(entityType, propOrFieldName); - if (memberInfo == null) - throw new ExpressionBuilder.ExpressionException($"'{propOrFieldName}' is not a valid property"); + var (memberInfo, memberType) = DBExpressionBuilderUtils.GetPropertyOrField(entityType, propOrFieldName); var parameterExpr = Expression.Parameter(entityType, "x"); var memberExpr = Expression.MakeMemberAccess(parameterExpr, memberInfo); @@ -42,10 +36,6 @@ public static IOrderedQueryable ApplyOrderBy(this IQueryable query, str "desc" => "OrderByDescending", _ => throw new ArgumentException(), }; - - var memberType = ExpressionBuilder.GetPropertyOrFieldType(memberInfo); - if (memberType == null) - throw new ExpressionBuilder.ExpressionException("Unknown error occured"); // Get the appropriate Queryable method (OrderBy or OrderByDescending) var method = typeof(Queryable).GetMethods() diff --git a/Common/Query/DBExpressionBuilder.cs b/Common/Query/DBExpressionBuilder.cs new file mode 100644 index 00000000..a7f075e1 --- /dev/null +++ b/Common/Query/DBExpressionBuilder.cs @@ -0,0 +1,113 @@ +using System.Linq.Expressions; +using System.Text.RegularExpressions; + +namespace OpenShock.Common.Query; + +public sealed class DBExpressionBuilderException : Exception +{ + public DBExpressionBuilderException(string message) : base(message) { } +} + +public static partial class DBExpressionBuilder +{ + [GeneratedRegex(@"^[A-Za-z][A-Za-z0-9]*$")] + private static partial Regex ValidMemberNameRegex(); + + private static Expression CreateMemberCompareExpression(Type entityType, ParameterExpression parameterExpr, string propOrFieldName, string operation, string value) where T : class + { + var (memberInfo, memberType) = DBExpressionBuilderUtils.GetPropertyOrField(entityType, propOrFieldName); + + var memberExpr = Expression.MakeMemberAccess(parameterExpr, memberInfo); + + Expression? resultExpr = operation switch + { + "like" => DBExpressionBuilderUtils.BuildEfFunctionsLikeExpression(memberType, memberExpr, value), + "ilike" => DBExpressionBuilderUtils.BuildEfFunctionsCollatedILikeExpression(memberType, memberExpr, value), + "==" or "eq" => DBExpressionBuilderUtils.BuildEqualExpression(memberType, memberExpr, value), + "!=" or "neq" => DBExpressionBuilderUtils.BuildNotEqualExpression(memberType, memberExpr, value), + "<" or "lt" => DBExpressionBuilderUtils.BuildLessThanExpression(memberType, memberExpr, value), + ">" or "gt" => DBExpressionBuilderUtils.BuildGreaterThanExpression(memberType, memberExpr, value), + "<=" or "lte" => DBExpressionBuilderUtils.BuildLessThanOrEqualExpression(memberType, memberExpr, value), + ">=" or "gte" => DBExpressionBuilderUtils.BuildGreaterThanOrEqualExpression(memberType, memberExpr, value), + _ => throw new DBExpressionBuilderException($"'{operation}' is not a supported operation type.") + }; + + return resultExpr ?? throw new DBExpressionBuilderException($"Operation {operation} is not supported for {memberType}"); + } + + + private sealed record ParsedFilter(string MemberName, string Operation, string Value); + private enum ExpectedToken + { + Member, + Operation, + Value, + AndOrEnd + } + private static IEnumerable ParseFilters(string query) + { + var member = string.Empty; + var operation = string.Empty; + var expectedToken = ExpectedToken.Member; + foreach (var word in QueryStringTokenizer.ParseQueryTokens(query)) + { + switch (expectedToken) + { + case ExpectedToken.Member: + member = word; + expectedToken = ExpectedToken.Operation; + break; + case ExpectedToken.Operation: + operation = word; + expectedToken = ExpectedToken.Value; + break; + case ExpectedToken.Value: + if (!ValidMemberNameRegex().IsMatch(member)) + throw new DBExpressionBuilderException("Invalid filter string!"); + + if (string.IsNullOrEmpty(operation)) + throw new DBExpressionBuilderException("Invalid filter string!"); + + yield return new ParsedFilter(member, operation, word); + + member = string.Empty; + operation = string.Empty; + expectedToken = ExpectedToken.AndOrEnd; + break; + case ExpectedToken.AndOrEnd: + if (word != "and") throw new DBExpressionBuilderException("Only and is supported atm!"); + expectedToken = ExpectedToken.Member; + break; + default: + throw new DBExpressionBuilderException("Unexpected state!"); + } + } + + if (expectedToken != ExpectedToken.AndOrEnd) + throw new DBExpressionBuilderException("Unexpected end of query"); + } + + public static Expression> GetFilterExpression(string filterQuery) where T : class + { + Expression? completeExpr = null; + + var entityType = typeof(T); + var parameterExpr = Expression.Parameter(entityType, "x"); + + foreach (var filter in ParseFilters(filterQuery)) + { + var memberExpr = CreateMemberCompareExpression(entityType, parameterExpr, filter.MemberName, filter.Operation, filter.Value); + + if (completeExpr == null) + { + completeExpr = memberExpr; + } + else + { + completeExpr = Expression.And(completeExpr, memberExpr); + } + } + + return Expression.Lambda>(completeExpr ?? Expression.Constant(true), parameterExpr); + } +} diff --git a/Common/Query/DBExpressionBuilderUtils.cs b/Common/Query/DBExpressionBuilderUtils.cs new file mode 100644 index 00000000..203fed15 --- /dev/null +++ b/Common/Query/DBExpressionBuilderUtils.cs @@ -0,0 +1,154 @@ +using Microsoft.EntityFrameworkCore; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.Serialization; + +namespace OpenShock.Common.Query; + +public static class DBExpressionBuilderUtils +{ + private static readonly MethodInfo EfFunctionsCollateMethodInfo = typeof(RelationalDbFunctionsExtensions).GetMethod("Collate")?.MakeGenericMethod(typeof(string)) ?? throw new MissingMethodException("EF.Functions", "Collate(string,string)"); + private static readonly MethodInfo EfFunctionsLikeMethodInfo = typeof(NpgsqlDbFunctionsExtensions).GetMethod("Like", [typeof(DbFunctions), typeof(string), typeof(string)]) ?? throw new MissingMethodException("EF.Functions", "Like(string,string)"); + private static readonly MethodInfo EfFunctionsILikeMethodInfo = typeof(NpgsqlDbFunctionsExtensions).GetMethod("ILike", [typeof(DbFunctions), typeof(string), typeof(string)]) ?? throw new MissingMethodException("EF.Functions", "ILike(string,string)"); + private static readonly MethodInfo StringEqualsMethodInfo = typeof(string).GetMethod("Equals", [typeof(string)]) ?? throw new MissingMethodException("string", "Equals(string,StringComparison)"); + private static readonly MethodInfo StringStartsWithMethodInfo = typeof(string).GetMethod("StartsWith", [typeof(string)]) ?? throw new MissingMethodException("string", "StartsWith(string)"); + private static readonly MethodInfo StringEndsWithMethodInfo = typeof(string).GetMethod("EndsWith", [typeof(string)]) ?? throw new MissingMethodException("string", "EndsWith(string)"); + private static readonly MethodInfo StringContainsMethodInfo = typeof(string).GetMethod("Contains", [typeof(string)]) ?? throw new MissingMethodException("string","Contains(string)"); + + /// + /// To not let whoever's requesting to explore hidden data structures, we return same exception for all errors here + /// + /// + /// + /// + /// + public static (MemberInfo, Type) GetPropertyOrField(Type type, string propOrFieldName) + { + var memberInfo = type.GetMember(propOrFieldName, BindingFlags.Public | BindingFlags.Instance | BindingFlags.GetProperty | BindingFlags.GetField | BindingFlags.IgnoreCase).SingleOrDefault(); + if (memberInfo == null) + throw new DBExpressionBuilderException($"'{propOrFieldName}' is not a valid property of type {type.Name}"); + + var isIgnored = memberInfo.GetCustomAttributes(typeof(IgnoreDataMemberAttribute), true).Any(); + if (isIgnored) + throw new DBExpressionBuilderException($"'{propOrFieldName}' is not a valid property of type {type.Name}"); + + var memberType = memberInfo switch + { + PropertyInfo prop => prop.PropertyType, + FieldInfo field => field.FieldType, + _ => throw new DBExpressionBuilderException($"'{propOrFieldName}' is not a valid property of type {type.Name}") + }; + + return (memberInfo, memberType); + } + + private static ConstantExpression GetConstant(Type type, string value) + { + if (type.IsEnum) + { + //Currently this causes a really weird bug which persists across subsequent requests + /* + var enumValue = Enum.Parse(type, value, ignoreCase: true); + return Expression.Constant(enumValue, type); + */ + + throw new NotImplementedException(); + } + + static object? HandleObject(Type type, string value) + { + if (type == typeof(Guid)) + { + return Guid.Parse(value); + } + + throw new NotImplementedException(); + } + + static object? HandleUnknown(Type type, string value) + { + + throw new NotImplementedException(); + } + + return Expression.Constant(Type.GetTypeCode(type) switch + { + TypeCode.Empty => throw new NotImplementedException(), + TypeCode.Object => HandleObject(type, value), + TypeCode.DBNull => throw new NotImplementedException(), + TypeCode.Boolean => Boolean.Parse(value), + TypeCode.Char => Char.Parse(value), + TypeCode.SByte => SByte.Parse(value), + TypeCode.Byte => Byte.Parse(value), + TypeCode.Int16 => Int16.Parse(value), + TypeCode.UInt16 => UInt16.Parse(value), + TypeCode.Int32 => Int32.Parse(value), + TypeCode.UInt32 => UInt32.Parse(value), + TypeCode.Int64 => Int64.Parse(value), + TypeCode.UInt64 => UInt64.Parse(value), + TypeCode.Single => Single.Parse(value), + TypeCode.Double => Double.Parse(value), + TypeCode.Decimal => Decimal.Parse(value), + TypeCode.DateTime => DateTime.Parse(value), + TypeCode.String => value, + _ => HandleUnknown(type, value), + }); + } + + public static MethodCallExpression? BuildEfFunctionsLikeExpression(Type memberType, Expression memberExpr, string value) + { + if (memberType != typeof(string)) return null; + + var valueConstant = Expression.Constant(value, typeof(string)); + var efFunctionsConstant = Expression.Constant(EF.Functions, typeof(DbFunctions)); + + return Expression.Call(null, EfFunctionsLikeMethodInfo, efFunctionsConstant, memberExpr, valueConstant); + } + + public static MethodCallExpression? BuildEfFunctionsCollatedILikeExpression(Type memberType, Expression memberExpr, string value) + { + if (memberType != typeof(string)) return null; + + var valueConstant = Expression.Constant(value, typeof(string)); + var defaultStrConstant = Expression.Constant("default", typeof(string)); + var efFunctionsConstant = Expression.Constant(EF.Functions, typeof(DbFunctions)); + + var collated = Expression.Call(null, EfFunctionsCollateMethodInfo, efFunctionsConstant, memberExpr, defaultStrConstant); + + return Expression.Call(null, EfFunctionsILikeMethodInfo, efFunctionsConstant, collated, valueConstant); + } + + public static BinaryExpression BuildEqualExpression(Type memberType, Expression memberExpr, string value) + { + return Expression.Equal(memberExpr, GetConstant(memberType, value)); + } + + public static BinaryExpression BuildNotEqualExpression(Type memberType, Expression memberExpr, string value) + { + return Expression.NotEqual(memberExpr, GetConstant(memberType, value)); + } + + public static BinaryExpression? BuildLessThanExpression(Type memberType, Expression memberExpr, string value) + { + if (memberType is { IsPrimitive: false, IsEnum: false } && Type.GetTypeCode(memberType) != TypeCode.DateTime) return null; + return Expression.LessThan(memberExpr, GetConstant(memberType, value)); + } + + public static BinaryExpression? BuildGreaterThanExpression(Type memberType, Expression memberExpr, string value) + { + if (memberType is { IsPrimitive: false, IsEnum: false } && Type.GetTypeCode(memberType) != TypeCode.DateTime) return null; + return Expression.GreaterThan(memberExpr, GetConstant(memberType, value)); + } + + public static BinaryExpression? BuildLessThanOrEqualExpression(Type memberType, Expression memberExpr, string value) + { + if (memberType is { IsPrimitive: false, IsEnum: false } && Type.GetTypeCode(memberType) != TypeCode.DateTime) return null; + return Expression.LessThanOrEqual(memberExpr, GetConstant(memberType, value)); + } + + public static BinaryExpression? BuildGreaterThanOrEqualExpression(Type memberType, Expression memberExpr, string value) + { + if (memberType is { IsPrimitive: false, IsEnum: false } && Type.GetTypeCode(memberType) != TypeCode.DateTime) return null; + return Expression.GreaterThanOrEqual(memberExpr, GetConstant(memberType, value)); + } +} diff --git a/Common/Query/QueryStringTokenizer.cs b/Common/Query/QueryStringTokenizer.cs new file mode 100644 index 00000000..9ce5bde7 --- /dev/null +++ b/Common/Query/QueryStringTokenizer.cs @@ -0,0 +1,134 @@ +using System.Buffers; +using System.Text; + +namespace OpenShock.Common.Query; + +public sealed class QueryStringTokenizerException : Exception +{ + public QueryStringTokenizerException(string message) : base(message) { } +} + +public static class QueryStringTokenizer +{ + private const char QueryQuoteChar = '\''; + private const char QueryEscapeChar = '\\'; + + // In unquoted strings, search for quotes and escapes. If these are found we should fail the parsing. + private static readonly SearchValues unquotedSearchValues = SearchValues.Create(' ', '\r', '\n', '\t', QueryQuoteChar, QueryEscapeChar); + + /// + /// Parses a query string into a list of words, handling spaces, quoted strings, and escape sequences. + /// + /// The input query as a . + /// A list of parsed words from the query. + /// + /// Thrown when the query contains an invalid escape sequence, an unclosed quoted string, or other syntax errors. + /// + /// + /// + /// var result = ParseQueryWords("hello world"); + /// result will contain: ["hello", "world"] + /// + /// var result = ParseQueryWords("'hello world'"); + /// result will contain: ["hello world"] + /// + /// var result = ParseQueryWords("this 'isn\'t invalid'"); + /// result will contain: ["this", "isn't invalid"] + /// + /// + public static List ParseQueryTokens(ReadOnlySpan query) + { + query = query.Trim(); + + List tokens = []; + + while (!query.IsEmpty) + { + int i; + if (query[0] != QueryQuoteChar) + { + i = query.IndexOfAny(unquotedSearchValues); + if (i < 0) + { + // End of query + tokens.Add(query.ToString()); + break; + } + + // Error on non-whitespace syntax character + if (!char.IsWhiteSpace(query[i])) + throw new QueryStringTokenizerException("Invalid unquoted string in query."); + + // Next space seperated part + tokens.Add(query[..i].ToString()); + + query = query[(i + 1)..].TrimStart(); + continue; + } + + // Skip quote char + query = query[1..]; + + // Find next quote or escape char + i = query.IndexOfAny(QueryQuoteChar, QueryEscapeChar); + if (i < 0) + throw new QueryStringTokenizerException("Closing quote not found."); + + // Fast path: string contains no escapes + if (query[i] == QueryQuoteChar) + { + // If i is 1 then its empty quotes + tokens.Add(i == 0 ? string.Empty : query[..i].ToString()); + query = query[(i + 1)..].TrimStart(); + continue; + } + + var sb = new StringBuilder(); + + // Parse escaped string + while (true) + { + // Add everything before escape + if (i > 0) sb.Append(query[..i]); + + // Needs space for escape sequence and end of string + if (i + 2 >= query.Length) + throw new QueryStringTokenizerException("Invalid end of query."); + + // Add escape + sb.Append(query[i + 1] switch + { + QueryQuoteChar => QueryQuoteChar, + QueryEscapeChar => QueryEscapeChar, + 'n' => '\n', + 'r' => '\r', + 't' => '\t', + _ => throw new QueryStringTokenizerException("Invalid escape sequence.") + }); + + // Skip past escape sequence + query = query[(i + 2)..]; + + i = query.IndexOfAny(QueryQuoteChar, QueryEscapeChar); + if (i < 0) + throw new QueryStringTokenizerException("Closing quote not found."); + + if (query[i] == QueryQuoteChar) + { + // Add everything before quote + if (i > 0) sb.Append(query[..i]); + + // Finish off string + tokens.Add(sb.ToString()); + + query = query[(i + 1)..].TrimStart(); + break; + } + + // Loop continues at escape found + } + } + + return tokens; + } +} diff --git a/Common/Utils/ExpressionBuilder.cs b/Common/Utils/ExpressionBuilder.cs deleted file mode 100644 index 2a628219..00000000 --- a/Common/Utils/ExpressionBuilder.cs +++ /dev/null @@ -1,259 +0,0 @@ -using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.Serialization; -using System.Text.RegularExpressions; -using Microsoft.EntityFrameworkCore; - -namespace OpenShock.Common.Utils; - -public static partial class ExpressionBuilder -{ - public sealed class ExpressionException : Exception - { - public ExpressionException(string message) : base(message) { } - } - - [GeneratedRegex(@"^[A-Za-z][A-Za-z0-9]*$")] - private static partial Regex ValidMemberNameRegex(); - - private static readonly MethodInfo EfFunctionsCollateMethodInfo = typeof(RelationalDbFunctionsExtensions).GetMethod("Collate")?.MakeGenericMethod(typeof(string)) ?? throw new ExpressionException("EF.Functions.Collate(string,string) not found"); - private static readonly MethodInfo EfFunctionsILikeMethodInfo = typeof(NpgsqlDbFunctionsExtensions).GetMethod("ILike", [typeof(DbFunctions), typeof(string), typeof(string) ]) ?? throw new ExpressionException("EF.Functions.ILike(string,string) not found"); - private static readonly MethodInfo StringEqualsMethodInfo = typeof(string).GetMethod("Equals", [typeof(string)]) ?? throw new ExpressionException("string.Equals(string,StringComparison) method not found"); - private static readonly MethodInfo StringStartsWithMethodInfo = typeof(string).GetMethod("StartsWith", [typeof(string)]) ?? throw new ExpressionException("string.StartsWith(string) method not found"); - private static readonly MethodInfo StringEndsWithMethodInfo = typeof(string).GetMethod("EndsWith", [typeof(string)]) ?? throw new ExpressionException("string.EndsWith(string) method not found"); - private static readonly MethodInfo StringContainsMethodInfo = typeof(string).GetMethod("Contains", [typeof(string)]) ?? throw new ExpressionException("string.Contains(string) method not found"); - - public static MemberInfo? GetPropertyOrField(Type type, string propOrFieldName) - { - var member = type.GetMember(propOrFieldName, BindingFlags.Public | BindingFlags.Instance | BindingFlags.GetProperty | BindingFlags.GetField | BindingFlags.IgnoreCase).SingleOrDefault(); - if (member == null) - return null; - - var isIgnored = member.GetCustomAttributes(typeof(IgnoreDataMemberAttribute), true).Any(); - if (isIgnored) - return null; - - return member; - } - - public static Type? GetPropertyOrFieldType(MemberInfo propOrField) - { - return propOrField switch - { - PropertyInfo prop => prop.PropertyType, - FieldInfo field => field.FieldType, - _ => null - }; - } - - private static ConstantExpression GetConstant(Type type, string value) - { - /* Currently this causes a really weird bug which persists across subsequent requests - if (type.IsEnum) - { - var enumValue = Enum.Parse(type, value, ignoreCase: true); - return Expression.Constant(enumValue, type); - } - */ - - return Expression.Constant(value, type); - } - - private static Expression BuildEfFunctionsCollatedILikeExpression(Type memberType, Expression memberExpr, string value) - { - if (memberType != typeof(string)) - throw new ExpressionException($"Operation ILIKE is not supported for {memberType}"); - - var valueConstant = Expression.Constant(value, typeof(string)); - var defaultStrConstant = Expression.Constant("default", typeof(string)); - var efFunctionsConstant = Expression.Constant(EF.Functions, typeof(DbFunctions)); - - var collated = Expression.Call(null, EfFunctionsCollateMethodInfo, efFunctionsConstant, memberExpr, defaultStrConstant); - - return Expression.Call(null, EfFunctionsILikeMethodInfo, efFunctionsConstant, collated, valueConstant); - } - - private static Expression BuildEqualExpression(Type memberType, Expression memberExpr, string value) - { - return Expression.Equal(memberExpr, GetConstant(memberType, value)); - } - - private static Expression BuildNotEqualExpression(Type memberType, Expression memberExpr, string value) - { - return Expression.NotEqual(memberExpr, GetConstant(memberType, value)); - } - - private static Expression BuildLessThanExpression(Type memberType, Expression memberExpr, string value) - { - if (memberType is { IsPrimitive: false, IsEnum: false }) - throw new ExpressionException($"Operation < is not supported for {memberType}"); - return Expression.LessThan(memberExpr, GetConstant(memberType, value)); - } - - private static Expression BuildGreaterThanExpression(Type memberType, Expression memberExpr, string value) - { - if (memberType is { IsPrimitive: false, IsEnum: false }) - throw new ExpressionException($"Operation > is not supported for {memberType}"); - return Expression.GreaterThan(memberExpr, GetConstant(memberType, value)); - } - - private static Expression BuildLessThanOrEqualExpression(Type memberType, Expression memberExpr, string value) - { - if (memberType is { IsPrimitive: false, IsEnum: false }) - throw new ExpressionException($"Operation <= is not supported for {memberType}"); - return Expression.LessThan(memberExpr, GetConstant(memberType, value)); - } - - private static Expression BuildGreaterThanOrEqualExpression(Type memberType, Expression memberExpr, string value) - { - if (memberType is { IsPrimitive: false, IsEnum: false }) - throw new ExpressionException($"Operation >= is not supported for {memberType}"); - return Expression.GreaterThan(memberExpr, GetConstant(memberType, value)); - } - - private static Expression CreateMemberCompareExpression(Type entityType, ParameterExpression parameterExpr, string propOrFieldName, string operation, string value) where T : class - { - var memberInfo = GetPropertyOrField(entityType, propOrFieldName); - if (memberInfo == null) - throw new ExpressionException($"'{propOrFieldName}' is not a valid property"); - - var memberType = GetPropertyOrFieldType(memberInfo); - if (memberType == null) - throw new ExpressionException("Unknown error occured"); - - var memberExpr = Expression.MakeMemberAccess(parameterExpr, memberInfo); - - return operation switch - { - "like" => BuildEfFunctionsCollatedILikeExpression(memberType, memberExpr, value), - "==" or "eq" => BuildEqualExpression(memberType, memberExpr, value), - "!=" or "neq" => BuildNotEqualExpression(memberType, memberExpr, value), - "<" or "lt" => BuildLessThanExpression(memberType, memberExpr, value), - ">" or "gt" => BuildGreaterThanExpression(memberType, memberExpr, value), - "<=" or "lte" => BuildLessThanOrEqualExpression(memberType, memberExpr, value), - ">=" or "gte" => BuildGreaterThanOrEqualExpression(memberType, memberExpr, value), - _ => throw new ExpressionException($"'{operation}' is not a supported operation type.") - }; - } - - private static List GetFilterWords(ReadOnlySpan query) - { - query = query.Trim(); - - List words = []; - while (!query.IsEmpty) - { - int index; - if (query[0] == '\'') - { - // Remove quote - query = query[1..]; - - // Look for next quote - index = query.IndexOf('\''); - - if (index < 0) - { - // No more quotes, throw error - throw new ExpressionException("Invalid query string, unterminated quote found."); - } - } - else - { - // Look for space - index = query.IndexOf(' '); - if (index < 0) - { - // No more spaces, return last word - words.Add(query.ToString()); - break; - } - } - - // Return next word - words.Add(query[..index].ToString()); - - // Remove word and spaces behind - query = query[(index + 1)..].TrimStart(' '); - } - - return words; - } - - private sealed record ParsedFilter(string MemberName, string Operation, string Value); - private enum ExpectedToken - { - Member, - Operation, - Value, - AndOrEnd - } - private static IEnumerable ParseFilters(string query) - { - var member = string.Empty; - var operation = string.Empty; - var expectedToken = ExpectedToken.Member; - foreach (var word in GetFilterWords(query)) - { - switch (expectedToken) - { - case ExpectedToken.Member: - member = word; - expectedToken = ExpectedToken.Operation; - break; - case ExpectedToken.Operation: - operation = word; - expectedToken = ExpectedToken.Value; - break; - case ExpectedToken.Value: - if (!ValidMemberNameRegex().IsMatch(member)) - throw new ExpressionException("Invalid filter string!"); - - if (string.IsNullOrEmpty(operation)) - throw new ExpressionException("Invalid filter string!"); - - yield return new ParsedFilter(member, operation, word); - - member = string.Empty; - operation = string.Empty; - expectedToken = ExpectedToken.AndOrEnd; - break; - case ExpectedToken.AndOrEnd: - if (word != "and") throw new ExpressionException("Only and is supported atm!"); - expectedToken = ExpectedToken.Member; - break; - default: - throw new ExpressionException("Unexpected state!"); - } - } - - if (expectedToken != ExpectedToken.AndOrEnd) - throw new ExpressionException("Unexpected end of query"); - } - - public static Expression>? GetFilterExpression(string filterQuery) where T : class - { - Expression? completeExpr = null; - - var entityType = typeof(T); - var parameterExpr = Expression.Parameter(entityType, "x"); - - foreach (var filter in ParseFilters(filterQuery)) - { - var memberExpr = CreateMemberCompareExpression(entityType, parameterExpr, filter.MemberName, filter.Operation, filter.Value); - - if (completeExpr == null) - { - completeExpr = memberExpr; - } - else - { - completeExpr = Expression.And(completeExpr, memberExpr); - } - } - - if (completeExpr == null) return null; - - return Expression.Lambda>(completeExpr, parameterExpr); - } -}