diff --git a/src/EFCore.Jet/Query/Sql/Internal/JetQuerySqlGenerator.cs b/src/EFCore.Jet/Query/Sql/Internal/JetQuerySqlGenerator.cs index a7e5e30..66c96b1 100644 --- a/src/EFCore.Jet/Query/Sql/Internal/JetQuerySqlGenerator.cs +++ b/src/EFCore.Jet/Query/Sql/Internal/JetQuerySqlGenerator.cs @@ -337,21 +337,20 @@ namespace EntityFrameworkCore.Jet.Query.Sql.Internal /// The select expression. protected override void GenerateTop(SelectExpression selectExpression) { - Check.NotNull(selectExpression, "selectExpression"); - if (selectExpression.Limit == null) - return; + Check.NotNull(selectExpression, nameof(selectExpression)); - Sql.Append("TOP "); - if (selectExpression.Offset == null) - Visit(selectExpression.Limit); - else + if (selectExpression.Offset != null) { - Visit(selectExpression.Limit); - Sql.Append("+"); - Visit(selectExpression.Offset); + // Jet does not support skipping rows. Use client evaluation instead. + throw new InvalidOperationException(CoreStrings.TranslationFailed(selectExpression.Offset)); } - Sql.Append(" "); + if (selectExpression.Limit != null) + { + Sql.Append("TOP "); + Visit(selectExpression.Limit); + Sql.Append(" "); + } } /// @@ -360,21 +359,7 @@ namespace EntityFrameworkCore.Jet.Query.Sql.Internal /// protected override void GenerateLimitOffset(SelectExpression selectExpression) { - // LIMIT is not natively supported by Jet. - // The System.Data.Jet tries to mitigate this by supporting a proprietary extension SKIP, but can easily - // fail, e.g. when the SKIP happens in a subquery. - - if (selectExpression.Offset == null) - return; - - // CHECK: Needed? - if (!selectExpression.Orderings.Any()) - Sql.AppendLine() - .Append("ORDER BY 0"); - - Sql.AppendLine() - .Append("SKIP "); - Visit(selectExpression.Offset); + // This has already been applied by GenerateTop(). } /// diff --git a/src/System.Data.Jet/JetCommand.cs b/src/System.Data.Jet/JetCommand.cs index 5def4e2..99e4e29 100644 --- a/src/System.Data.Jet/JetCommand.cs +++ b/src/System.Data.Jet/JetCommand.cs @@ -18,7 +18,7 @@ namespace System.Data.Jet private Guid? _lastGuid; private int? _rowCount; - private static readonly Regex _skipRegularExpression = new Regex(@"\bskip\s(?@.*)\b", RegexOptions.IgnoreCase); + private static readonly Regex _topRegularExpression = new Regex(@"(?<=(?:^|\s)select\s+top\s+)(?:\d+|(?:@\w+)|\?)(?=\s)", RegexOptions.IgnoreCase); private static readonly Regex _selectRowCountRegularExpression = new Regex(@"^\s*select\s*@@rowcount\s*;?\s*$", RegexOptions.IgnoreCase); private static readonly Regex _ifStatementRegex = new Regex(@"^\s*if\s*(?not)?\s*exists\s*\((?.+)\)\s*then\s*(?.*)$", RegexOptions.IgnoreCase); @@ -306,24 +306,15 @@ namespace System.Data.Jet private JetDataReader InternalExecuteDbDataReader(string commandText, CommandBehavior behavior) { - ParseSkipTop(commandText, out var topCount, out var skipCount, out var newCommandText); + var newCommandText = ApplyTopParameters(commandText); + SortParameters(newCommandText, InnerCommand.Parameters); FixParameters(InnerCommand.Parameters); var command = (DbCommand) ((ICloneable) InnerCommand).Clone(); command.CommandText = newCommandText; - JetDataReader dataReader; - - if (skipCount != 0) - dataReader = new JetDataReader( - command.ExecuteReader(behavior), topCount == -1 - ? 0 - : topCount - skipCount, skipCount); - else if (topCount >= 0) - dataReader = new JetDataReader(command.ExecuteReader(behavior), topCount, 0); - else - dataReader = new JetDataReader(command.ExecuteReader(behavior)); + var dataReader = new JetDataReader(command.ExecuteReader(behavior)); _rowCount = dataReader.RecordsAffected; @@ -332,11 +323,10 @@ namespace System.Data.Jet private int InternalExecuteNonQuery(string commandText) { - // ReSharper disable NotAccessedVariable - // ReSharper restore NotAccessedVariable if (!CheckExists(commandText, out var newCommandText)) return 0; - ParseSkipTop(newCommandText, out var topCount, out var skipCount, out newCommandText); + + newCommandText = ApplyTopParameters(newCommandText); SortParameters(newCommandText, InnerCommand.Parameters); FixParameters(InnerCommand.Parameters); @@ -487,71 +477,147 @@ namespace System.Data.Jet return commandText; } - private void ParseSkipTop(string commandText, out int topCount, out int skipCount, out string newCommandText) + private string ApplyTopParameters(string commandText) { - newCommandText = commandText; + // We inline all TOP clause parameters of all SELECT statements, because Jet does not support parameters + // in TOP clauses. + var lastCommandText = commandText; + var parameters = InnerCommand.Parameters.Cast().ToList(); + + while ((commandText = _topRegularExpression.Replace( + commandText, + match => (IsParameter(match.Value) + ? Convert.ToInt32(GetOrExtractParameter(commandText, match.Value, match.Index, parameters).Value) + : int.Parse(match.Value)) + .ToString(), 1)) != lastCommandText) + { + lastCommandText = commandText; + } + + InnerCommand.Parameters.Clear(); + InnerCommand.Parameters.AddRange(parameters.ToArray()); - #region TOP clause + return commandText; + } - topCount = -1; - skipCount = 0; + protected virtual bool IsParameter(string fragment) + => fragment.StartsWith("@") || + fragment.Equals("?"); - var indexOfTop = newCommandText.IndexOf(" top ", StringComparison.InvariantCultureIgnoreCase); - while (indexOfTop != -1) + protected virtual DbParameter GetOrExtractParameter(string commandText, string name, int count, List parameters) + { + if (name.Equals("?")) { - var indexOfTopEnd = newCommandText.IndexOf(" ", indexOfTop + 5, StringComparison.InvariantCultureIgnoreCase); - var stringTopCount = newCommandText.Substring(indexOfTop + 5, indexOfTopEnd - indexOfTop - 5) - .Trim(); - var stringTopCountElements = stringTopCount.Split('+'); - int topCount0; - int topCount1; - - if (stringTopCountElements[0] - .StartsWith("@")) - topCount0 = Convert.ToInt32( - InnerCommand.Parameters[stringTopCountElements[0]] - .Value); - else if (!int.TryParse(stringTopCountElements[0], out topCount0)) - throw new Exception("Invalid TOP clause parameter"); - - if (stringTopCountElements.Length == 1) - topCount1 = 0; - else if (stringTopCountElements[1] - .StartsWith("@")) - topCount1 = Convert.ToInt32( - InnerCommand.Parameters[stringTopCountElements[1]] - .Value); - else if (!int.TryParse(stringTopCountElements[1], out topCount1)) - throw new Exception("Invalid TOP clause parameter"); - - var localTopCount = topCount0 + topCount1; - newCommandText = newCommandText.Remove(indexOfTop + 5, stringTopCount.Length) - .Insert(indexOfTop + 5, localTopCount.ToString()); - if (indexOfTop <= 12) - topCount = localTopCount; - indexOfTop = newCommandText.IndexOf(" top ", indexOfTop + 5, StringComparison.InvariantCultureIgnoreCase); + var index = GetOdbcParameterCount(commandText.Substring(0, count)); + var parameter = InnerCommand.Parameters[index]; + + parameters.RemoveAt(index); + + return parameter; } - #endregion - - #region SKIP clause + return InnerCommand.Parameters[name]; + } - var matchSkipRegularExpression = _skipRegularExpression.Match(newCommandText); - if (matchSkipRegularExpression.Success) + private static int GetOdbcParameterCount(string sqlFragment) + { + var parameterCount = 0; + + // We use '\0' as the default state and char. + var state = '\0'; + var lastChar = '\0'; + + // State machine to count ODBC parameter occurrences. + foreach (var c in sqlFragment) { - var stringSkipCount = matchSkipRegularExpression.Groups["stringSkipCount"] - .Value; - - if (stringSkipCount.StartsWith("@")) - skipCount = Convert.ToInt32( - InnerCommand.Parameters[stringSkipCount] - .Value); - else if (!int.TryParse(stringSkipCount, out skipCount)) - throw new Exception("Invalid SKIP clause parameter"); - newCommandText = newCommandText.Remove(matchSkipRegularExpression.Index, matchSkipRegularExpression.Length); + if (state == '\'') + { + // We are currently inside a string, or closed the string in the last iteration but didn't + // know that at the time, because it still could have been the beginning of an escape sequence. + + if (c == '\'') + { + // We either end the string, begin an escape sequence or end an escape sequence. + if (lastChar == '\'') + { + // This is the end of an escape sequence. + // We continue being in a string. + lastChar = '\0'; + } + else + { + // This is either the beginning of an escape sequence, or the end of the string. + // We will know the in the next iteration. + lastChar = '\''; + } + } + else if (lastChar == '\'') + { + // The last iteration was the end of as string. + // Reset the current state and continue processing the current char. + state = '\0'; + lastChar = '\0'; + } + } + + if (state == '"') + { + // We are currently inside a string, or closed the string in the last iteration but didn't + // know that at the time, because it still could have been the beginning of an escape sequence. + + if (c == '"') + { + // We either end the string, begin an escape sequence or end an escape sequence. + if (lastChar == '"') + { + // This is the end of an escape sequence. + // We continue being in a string. + lastChar = '\0'; + } + else + { + // This is either the beginning of an escape sequence, or the end of the string. + // We will know the in the next iteration. + lastChar = '"'; + } + } + else if (lastChar == '"') + { + // The last iteration was the end of as string. + // Reset the current state and continue processing the current char. + state = '\0'; + lastChar = '\0'; + } + } + + if (state == '\0') + { + if (c == '"') + { + state = '"'; + } + else if (c == '\'') + { + state = '\''; + } + else if (c == '`') + { + state = '`'; + } + else if (c == '?') + { + parameterCount++; + } + } + + if (state == '`' && + c == '`') + { + state = '\0'; + } } - #endregion + return parameterCount; } /// diff --git a/src/System.Data.Jet/JetDataReader.cs b/src/System.Data.Jet/JetDataReader.cs index 9695d5b..48c63d2 100644 --- a/src/System.Data.Jet/JetDataReader.cs +++ b/src/System.Data.Jet/JetDataReader.cs @@ -18,19 +18,7 @@ namespace System.Data.Jet _wrappedDataReader = dataReader; } - public JetDataReader(DbDataReader dataReader, int topCount, int skipCount) - : this(dataReader) - { - _topCount = topCount; - for (var i = 0; i < skipCount; i++) - { - _wrappedDataReader.Read(); - } - } - private readonly DbDataReader _wrappedDataReader; - private readonly int _topCount; - private int _readCount; public override void Close() { @@ -118,34 +106,22 @@ namespace System.Data.Jet => GetDateTime(ordinal) - JetConfiguration.TimeSpanOffset; public virtual DateTimeOffset GetDateTimeOffset(int ordinal) - { - return GetDateTime(ordinal); - } + => GetDateTime(ordinal); public override decimal GetDecimal(int ordinal) - { - return Convert.ToDecimal(_wrappedDataReader.GetValue(ordinal)); - } + => Convert.ToDecimal(_wrappedDataReader.GetValue(ordinal)); public override double GetDouble(int ordinal) - { - return Convert.ToDouble(_wrappedDataReader.GetValue(ordinal)); - } + => Convert.ToDouble(_wrappedDataReader.GetValue(ordinal)); public override System.Collections.IEnumerator GetEnumerator() - { - return _wrappedDataReader.GetEnumerator(); - } + => _wrappedDataReader.GetEnumerator(); public override Type GetFieldType(int ordinal) - { - return _wrappedDataReader.GetFieldType(ordinal); - } + => _wrappedDataReader.GetFieldType(ordinal); public override float GetFloat(int ordinal) - { - return Convert.ToSingle(_wrappedDataReader.GetValue(ordinal)); - } + => Convert.ToSingle(_wrappedDataReader.GetValue(ordinal)); public override Guid GetGuid(int ordinal) { @@ -158,9 +134,7 @@ namespace System.Data.Jet } public override short GetInt16(int ordinal) - { - return Convert.ToInt16(_wrappedDataReader.GetValue(ordinal)); - } + => Convert.ToInt16(_wrappedDataReader.GetValue(ordinal)); public override int GetInt32(int ordinal) { @@ -177,34 +151,22 @@ namespace System.Data.Jet } public override long GetInt64(int ordinal) - { - return Convert.ToInt64(_wrappedDataReader.GetValue(ordinal)); - } + => Convert.ToInt64(_wrappedDataReader.GetValue(ordinal)); public override string GetName(int ordinal) - { - return _wrappedDataReader.GetName(ordinal); - } + => _wrappedDataReader.GetName(ordinal); public override int GetOrdinal(string name) - { - return _wrappedDataReader.GetOrdinal(name); - } + => _wrappedDataReader.GetOrdinal(name); - public override System.Data.DataTable GetSchemaTable() - { - return _wrappedDataReader.GetSchemaTable(); - } + public override DataTable GetSchemaTable() + => _wrappedDataReader.GetSchemaTable(); public override string GetString(int ordinal) - { - return _wrappedDataReader.GetString(ordinal); - } + => _wrappedDataReader.GetString(ordinal); public override object GetValue(int ordinal) - { - return _wrappedDataReader.GetValue(ordinal); - } + => _wrappedDataReader.GetValue(ordinal); public override T GetFieldValue(int ordinal) { @@ -217,9 +179,7 @@ namespace System.Data.Jet } public override int GetValues(object[] values) - { - return _wrappedDataReader.GetValues(values); - } + => _wrappedDataReader.GetValues(values); public override bool HasRows => _wrappedDataReader.HasRows; @@ -237,18 +197,10 @@ namespace System.Data.Jet } public override bool NextResult() - { - return _wrappedDataReader.NextResult(); - } + => _wrappedDataReader.NextResult(); public override bool Read() - { - _readCount++; - if (_topCount != 0 && _readCount > _topCount) - return false; - - return _wrappedDataReader.Read(); - } + => _wrappedDataReader.Read(); public override int RecordsAffected => _wrappedDataReader.RecordsAffected; diff --git a/test/EFCore.Jet.FunctionalTests/TestUtilities/AssertSqlHelper.cs b/test/EFCore.Jet.FunctionalTests/TestUtilities/AssertSqlHelper.cs index 8d9bbfa..3743739 100644 --- a/test/EFCore.Jet.FunctionalTests/TestUtilities/AssertSqlHelper.cs +++ b/test/EFCore.Jet.FunctionalTests/TestUtilities/AssertSqlHelper.cs @@ -1,3 +1,4 @@ +using System; using System.Data.Jet; namespace EntityFrameworkCore.Jet.FunctionalTests.TestUtilities @@ -18,5 +19,13 @@ namespace EntityFrameworkCore.Jet.FunctionalTests.TestUtilities => dataAccessProviderType == DataAccessProviderType.Odbc ? "?" : name; + + public static string Declaration(string fullDeclaration) + => Declaration(fullDeclaration, DataAccessProviderType); + + public static string Declaration(string fullDeclaration, DataAccessProviderType dataAccessProviderType) + => dataAccessProviderType == DataAccessProviderType.Odbc + ? string.Empty + : fullDeclaration; } } \ No newline at end of file