diff --git a/src/EFCore.Jet/Query/Sql/Internal/JetQuerySqlGenerator.cs b/src/EFCore.Jet/Query/Sql/Internal/JetQuerySqlGenerator.cs index 445f1b4..5928057 100644 --- a/src/EFCore.Jet/Query/Sql/Internal/JetQuerySqlGenerator.cs +++ b/src/EFCore.Jet/Query/Sql/Internal/JetQuerySqlGenerator.cs @@ -5,11 +5,13 @@ using System.Collections.Generic; using System.Data.Jet; using System.Linq; using System.Linq.Expressions; +using System.Text; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Storage; using EntityFrameworkCore.Jet.Utilities; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; @@ -349,8 +351,7 @@ namespace EntityFrameworkCore.Jet.Query.Sql.Internal if (selectExpression.Offset != null) { - // Jet does not support skipping rows. Use client evaluation instead. - throw new InvalidOperationException(CoreStrings.TranslationFailed(selectExpression.Offset)); + throw new InvalidOperationException("Jet does not support skipping rows. Switch to client evaluation explicitly by inserting a call to either AsEnumerable(), AsAsyncEnumerable(), ToList(), or ToListAsync() if needed."); } if (selectExpression.Limit != null) @@ -430,5 +431,8 @@ namespace EntityFrameworkCore.Jet.Query.Sql.Internal return caseExpression; } + + protected override Expression VisitRowNumber(RowNumberExpression rowNumberExpression) + => throw new InvalidOperationException(CoreStrings.TranslationFailed(rowNumberExpression)); } } \ No newline at end of file diff --git a/src/EFCore.Jet/Storage/Internal/JetSqlGenerationHelper.cs b/src/EFCore.Jet/Storage/Internal/JetSqlGenerationHelper.cs index d85e846..0efda05 100644 --- a/src/EFCore.Jet/Storage/Internal/JetSqlGenerationHelper.cs +++ b/src/EFCore.Jet/Storage/Internal/JetSqlGenerationHelper.cs @@ -1,8 +1,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System.Data.Jet; using System.Text; -using EntityFrameworkCore.Jet.Infrastructure.Internal; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Storage; using EntityFrameworkCore.Jet.Utilities; @@ -15,18 +13,14 @@ namespace EntityFrameworkCore.Jet.Storage.Internal /// public class JetSqlGenerationHelper : RelationalSqlGenerationHelper { - private readonly IJetOptions _jetOptions; - /// /// This API supports the Entity Framework Core infrastructure and is not intended to be used /// directly from your code. This API may change or be removed in future releases. /// public JetSqlGenerationHelper( - [NotNull] RelationalSqlGenerationHelperDependencies dependencies, - [NotNull] IJetOptions jetOptions) + [NotNull] RelationalSqlGenerationHelperDependencies dependencies) : base(dependencies) { - _jetOptions = jetOptions; } /// @@ -91,23 +85,6 @@ namespace EntityFrameworkCore.Jet.Storage.Internal return DelimitIdentifier(Check.NotEmpty(name, nameof(name))); } - public override string GenerateParameterNamePlaceholder(string name) - => _jetOptions.DataAccessProviderType == DataAccessProviderType.OleDb - ? base.GenerateParameterNamePlaceholder(name) - : "?"; - - public override void GenerateParameterNamePlaceholder(StringBuilder builder, string name) - { - if (_jetOptions.DataAccessProviderType == DataAccessProviderType.OleDb) - { - base.GenerateParameterNamePlaceholder(builder, name); - } - else - { - builder.Append("?"); - } - } - public static string TruncateIdentifier(string identifier) { if (identifier.Length <= 64) diff --git a/src/System.Data.Jet/AdoxWrapper.cs b/src/System.Data.Jet/AdoxWrapper.cs index b1f0de4..8933334 100644 --- a/src/System.Data.Jet/AdoxWrapper.cs +++ b/src/System.Data.Jet/AdoxWrapper.cs @@ -95,7 +95,7 @@ namespace System.Data.Jet connection.DataAccessProviderFactory = dataAccessProviderFactory; connection.Open(); - var sql = @"CREATE TABLE `MSysAccessStorage` ( + var script = @"CREATE TABLE `MSysAccessStorage` ( `DateCreate` DATETIME NULL, `DateUpdate` DATETIME NULL, `Id` COUNTER NOT NULL, @@ -108,8 +108,11 @@ namespace System.Data.Jet CREATE UNIQUE INDEX `ParentIdId` ON `MSysAccessStorage` (`ParentId`, `Id`); CREATE UNIQUE INDEX `ParentIdName` ON `MSysAccessStorage` (`ParentId`, `Name`);"; - using var command = connection.CreateCommand(sql); - command.ExecuteNonQuery(); + foreach (var commandText in script.Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries)) + { + using var command = connection.CreateCommand(commandText); + command.ExecuteNonQuery(); + } } catch (Exception e) { diff --git a/src/System.Data.Jet/JetCommand.cs b/src/System.Data.Jet/JetCommand.cs index 99e4e29..1169abe 100644 --- a/src/System.Data.Jet/JetCommand.cs +++ b/src/System.Data.Jet/JetCommand.cs @@ -18,7 +18,8 @@ namespace System.Data.Jet private Guid? _lastGuid; private int? _rowCount; - private static readonly Regex _topRegularExpression = new Regex(@"(?<=(?:^|\s)select\s+top\s+)(?:\d+|(?:@\w+)|\?)(?=\s)", RegexOptions.IgnoreCase); + private static readonly Regex _createProcedureExpression = new Regex(@"^\s*create\s*procedure\b", RegexOptions.IgnoreCase); + private static readonly Regex _topParameterRegularExpression = new Regex(@"(?<=(?:^|\s)select\s+top\s+)(?:@\w+|\?)(?=\s)", RegexOptions.IgnoreCase); private static readonly Regex _selectRowCountRegularExpression = new Regex(@"^\s*select\s*@@rowcount\s*;?\s*$", RegexOptions.IgnoreCase); private static readonly Regex _ifStatementRegex = new Regex(@"^\s*if\s*(?not)?\s*exists\s*\((?.+)\)\s*then\s*(?.*)$", RegexOptions.IgnoreCase); @@ -176,6 +177,8 @@ namespace System.Data.Jet // OLE DB forces us to use an existing active transaction, if one is available. InnerCommand.Transaction = _transaction?.WrappedTransaction ?? _connection.ActiveTransaction?.WrappedTransaction; + ExpandParameters(); + LogHelper.ShowCommandText("ExecuteDbDataReader", InnerCommand); if (JetStoreSchemaDefinitionRetrieve.TryGetDataReaderFromShowCommand(InnerCommand, _connection.JetFactory.InnerFactory, out var dataReader)) @@ -185,19 +188,17 @@ namespace System.Data.Jet if (InnerCommand.CommandType != CommandType.Text) return new JetDataReader(InnerCommand.ExecuteReader(behavior)); - var commandTextList = SplitCommands(InnerCommand.CommandText); - - dataReader = null; - foreach (var t in commandTextList) + if ((dataReader = TryGetDataReaderForSelectRowCount(InnerCommand.CommandText)) == null) { - var commandText = t; - if ((dataReader = TryGetDataReaderForSelectRowCount(commandText)) != null) - continue; + InnerCommand.CommandText = ParseIdentity(InnerCommand.CommandText); + InnerCommand.CommandText = ParseGuid(InnerCommand.CommandText); - commandText = ParseIdentity(commandText); - commandText = ParseGuid(commandText); + InlineTopParameters(); + FixParameters(); - dataReader = InternalExecuteDbDataReader(commandText, behavior); + dataReader = new JetDataReader(InnerCommand.ExecuteReader(behavior)); + + _rowCount = dataReader.RecordsAffected; } return dataReader; @@ -205,8 +206,7 @@ namespace System.Data.Jet private DbDataReader TryGetDataReaderForSelectRowCount(string commandText) { - if (_selectRowCountRegularExpression.Match(commandText) - .Success) + if (_selectRowCountRegularExpression.Match(commandText).Success) { if (_rowCount == null) throw new InvalidOperationException("Invalid " + commandText + ". Run a DataReader before."); @@ -227,6 +227,8 @@ namespace System.Data.Jet { if (Connection == null) throw new InvalidOperationException(Messages.PropertyNotInitialized(nameof(Connection))); + + ExpandParameters(); LogHelper.ShowCommandText("ExecuteNonQuery", InnerCommand); @@ -245,29 +247,30 @@ namespace System.Data.Jet if (InnerCommand.CommandType != CommandType.Text) return InnerCommand.ExecuteNonQuery(); + + if (_selectRowCountRegularExpression.Match(InnerCommand.CommandText) + .Success) + { + // TODO: Fix exception message. + if (_rowCount == null) + throw new InvalidOperationException("Invalid " + InnerCommand.CommandText + ". Run a DataReader before."); + return _rowCount.Value; + } - var commandTextList = SplitCommands(InnerCommand.CommandText); + InnerCommand.CommandText = ParseIdentity(InnerCommand.CommandText); + InnerCommand.CommandText = ParseGuid(InnerCommand.CommandText); - var returnValue = -1; - foreach (var t in commandTextList) - { - var commandText = t; - if (_selectRowCountRegularExpression.Match(commandText) - .Success) - { - if (_rowCount == null) - throw new InvalidOperationException("Invalid " + commandText + ". Run a DataReader before."); - returnValue = _rowCount.Value; - continue; - } + if (!CheckExists(InnerCommand.CommandText, out var newCommandText)) + return 0; - commandText = ParseIdentity(commandText); - commandText = ParseGuid(commandText); + InnerCommand.CommandText = newCommandText; + + InlineTopParameters(); + FixParameters(); - returnValue = InternalExecuteNonQuery(commandText); - } + _rowCount = InnerCommand.ExecuteNonQuery(); - return returnValue; + return _rowCount.Value; } /// @@ -287,6 +290,8 @@ namespace System.Data.Jet // OLE DB forces us to use an existing active transaction, if one is available. InnerCommand.Transaction = _transaction?.WrappedTransaction ?? _connection.ActiveTransaction?.WrappedTransaction; + ExpandParameters(); + LogHelper.ShowCommandText("ExecuteScalar", InnerCommand); if (JetStoreSchemaDefinitionRetrieve.TryGetDataReaderFromShowCommand(InnerCommand, _connection.JetFactory.InnerFactory, out var dataReader)) @@ -301,42 +306,10 @@ namespace System.Data.Jet return DBNull.Value; } - return InnerCommand.ExecuteScalar(); - } - - private JetDataReader InternalExecuteDbDataReader(string commandText, CommandBehavior behavior) - { - var newCommandText = ApplyTopParameters(commandText); - - SortParameters(newCommandText, InnerCommand.Parameters); - FixParameters(InnerCommand.Parameters); - - var command = (DbCommand) ((ICloneable) InnerCommand).Clone(); - command.CommandText = newCommandText; - - var dataReader = new JetDataReader(command.ExecuteReader(behavior)); - - _rowCount = dataReader.RecordsAffected; - - return dataReader; - } - - private int InternalExecuteNonQuery(string commandText) - { - if (!CheckExists(commandText, out var newCommandText)) - return 0; - - newCommandText = ApplyTopParameters(newCommandText); - - SortParameters(newCommandText, InnerCommand.Parameters); - FixParameters(InnerCommand.Parameters); - - var command = (DbCommand) ((ICloneable) InnerCommand).Clone(); - command.CommandText = newCommandText; - - _rowCount = command.ExecuteNonQuery(); + InlineTopParameters(); + FixParameters(); - return _rowCount.Value; + return InnerCommand.ExecuteScalar(); } private bool CheckExists(string commandText, out string newCommandText) @@ -368,10 +341,13 @@ namespace System.Data.Jet return hasRows; } - private void FixParameters(DbParameterCollection parameters) + private void FixParameters() { + var parameters = InnerCommand.Parameters; + if (parameters.Count == 0) return; + foreach (DbParameter parameter in parameters) { if (parameter.Value is TimeSpan ts) @@ -384,65 +360,9 @@ namespace System.Data.Jet } } - private void SortParameters(string query, DbParameterCollection parameters) - { - if (parameters.Count == 0) - return; - - var parameterArray = parameters.Cast() - .OrderBy(p => p, new ParameterPositionComparer(query)) - .ToArray(); - - parameters.Clear(); - parameters.AddRange(parameterArray); - } - - private class ParameterPositionComparer : IComparer - { - private readonly string _query; - - public ParameterPositionComparer(string query) - { - _query = query; - } - - public int Compare(DbParameter x, DbParameter y) - { - if (x == null) - throw new ArgumentNullException(nameof(x)); - if (y == null) - throw new ArgumentNullException(nameof(y)); - - var xPosition = _query.IndexOf(x.ParameterName, StringComparison.Ordinal); - var yPosition = _query.IndexOf(y.ParameterName, StringComparison.Ordinal); - if (xPosition == -1) - xPosition = int.MaxValue; - if (yPosition == -1) - yPosition = int.MaxValue; - return xPosition.CompareTo(yPosition); - } - } - - private string[] SplitCommands(string command) - { - var commandParts = - command.Replace("\r\n", "\n") - .Replace("\r", "\n") - .Split(new[] {";\n"}, StringSplitOptions.None); - var commands = new List(commandParts.Length); - foreach (var commandPart in commandParts) - { - if (!string.IsNullOrWhiteSpace( - commandPart.Replace("\n", "") - .Replace(";", ""))) - commands.Add(commandPart); - } - - return commands.ToArray(); - } - private string ParseIdentity(string commandText) { + // TODO: Fix the following code, that does work only for common scenarios. Use state machine instead. if (commandText.ToLower() .Contains("@@identity")) { @@ -460,6 +380,7 @@ namespace System.Data.Jet private string ParseGuid(string commandText) { + // TODO: Fix the following code, that does work only for common scenarios. Use state machine instead. while (commandText.ToLower() .Contains("newguid()")) { @@ -477,59 +398,160 @@ namespace System.Data.Jet return commandText; } - private string ApplyTopParameters(string commandText) + private void InlineTopParameters() { // We inline all TOP clause parameters of all SELECT statements, because Jet does not support parameters // in TOP clauses. - var lastCommandText = commandText; var parameters = InnerCommand.Parameters.Cast().ToList(); - - while ((commandText = _topRegularExpression.Replace( - commandText, - match => (IsParameter(match.Value) - ? Convert.ToInt32(GetOrExtractParameter(commandText, match.Value, match.Index, parameters).Value) - : int.Parse(match.Value)) - .ToString(), 1)) != lastCommandText) + + if (parameters.Count > 0) { - lastCommandText = commandText; - } - - InnerCommand.Parameters.Clear(); - InnerCommand.Parameters.AddRange(parameters.ToArray()); + var lastCommandText = InnerCommand.CommandText; + var commandText = lastCommandText; - return commandText; + while ((commandText = _topParameterRegularExpression.Replace( + lastCommandText, + match => Convert.ToInt32(ExtractParameter(commandText, match.Value, match.Index, parameters).Value).ToString(), + 1)) != lastCommandText) + { + lastCommandText = commandText; + } + + InnerCommand.CommandText = commandText; + + InnerCommand.Parameters.Clear(); + InnerCommand.Parameters.AddRange(parameters.ToArray()); + } } protected virtual bool IsParameter(string fragment) => fragment.StartsWith("@") || fragment.Equals("?"); - protected virtual DbParameter GetOrExtractParameter(string commandText, string name, int count, List parameters) + protected virtual DbParameter ExtractParameter(string commandText, string name, int count, List parameters) { - if (name.Equals("?")) + var indices = GetParameterIndices(commandText.Substring(0, count)); + var parameter = InnerCommand.Parameters[indices.Count]; + + parameters.RemoveAt(indices.Count); + + return parameter; + } + + protected virtual void ExpandParameters() + { + if (_createProcedureExpression.IsMatch(InnerCommand.CommandText)) + { + return; + } + + var indices = GetParameterIndices(InnerCommand.CommandText); + + if (indices.Count <= 0) + { + return; + } + + var placeholders = GetParameterPlaceholders(InnerCommand.CommandText, indices); + + if (placeholders.All(t => t.Name.StartsWith("@"))) { - var index = GetOdbcParameterCount(commandText.Substring(0, count)); - var parameter = InnerCommand.Parameters[index]; + MatchParametersAndPlaceholders(placeholders); + + if (JetConnection.GetDataAccessProviderType(_connection.DataAccessProviderFactory) == DataAccessProviderType.Odbc) + { + foreach (var placeholder in placeholders.Reverse()) + { + InnerCommand.CommandText = InnerCommand.CommandText + .Remove(placeholder.Index, placeholder.Name.Length) + .Insert(placeholder.Index, "?"); + } + } - parameters.RemoveAt(index); + InnerCommand.Parameters.Clear(); + InnerCommand.Parameters.AddRange(placeholders.Select(p => p.Parameter).ToArray()); + } + else if (placeholders.All(t => t.Name == "?")) + { + throw new InvalidOperationException("Parameter placeholder count does not match parameter count."); + } + else + { + throw new InvalidOperationException("Inconsistent parameter placeholder naming used."); + } + } + + protected virtual void MatchParametersAndPlaceholders(IReadOnlyList placeholders) + { + var unusedParameters = InnerCommand.Parameters + .Cast() + .ToList(); + + foreach (var placeholder in placeholders) + { + var parameter = unusedParameters + .FirstOrDefault(p => placeholder.Name.Equals(p.ParameterName, StringComparison.Ordinal)); + + if (parameter != null) + { + placeholder.Parameter = parameter; + unusedParameters.Remove(parameter); + } + else + { + parameter = placeholders + .FirstOrDefault(p => placeholder.Name.Equals(p.Name, StringComparison.Ordinal)) + ?.Parameter; + + if (parameter == null) + { + throw new InvalidOperationException($"Cannot find parameter with same name as parameter placeholder \"{placeholder.Name}\"."); + } + + var newParameter = (DbParameter) (parameter as ICloneable)?.Clone(); + + if (newParameter == null) + { + throw new InvalidOperationException($"Cannot clone parameter \"{parameter.ParameterName}\"."); + } + + placeholder.Parameter = newParameter; + } + } + } + + protected virtual IReadOnlyList GetParameterPlaceholders(string commandText, IEnumerable indices) + { + var placeholders = new List(); + + foreach (var index in indices) + { + var match = Regex.Match(commandText.Substring(index), @"^(?:\?|@\w+)"); + + if (!match.Success) + { + throw new InvalidOperationException("Invalid parameter placeholder found."); + } - return parameter; + placeholders.Add(new ParameterPlaceholder{ Index = index, Name = match.Value }); } - return InnerCommand.Parameters[name]; + return placeholders.AsReadOnly(); } - private static int GetOdbcParameterCount(string sqlFragment) + protected virtual IReadOnlyList GetParameterIndices(string sqlFragment) { - var parameterCount = 0; + var parameterIndices = new List(); // We use '\0' as the default state and char. var state = '\0'; var lastChar = '\0'; // State machine to count ODBC parameter occurrences. - foreach (var c in sqlFragment) + for (var i = 0; i < sqlFragment.Length; i++) { + var c = sqlFragment[i]; + if (state == '\'') { // We are currently inside a string, or closed the string in the last iteration but didn't @@ -604,9 +626,10 @@ namespace System.Data.Jet { state = '`'; } - else if (c == '?') + else if (c == '?' || + c == '@') { - parameterCount++; + parameterIndices.Add(i); } } @@ -617,7 +640,7 @@ namespace System.Data.Jet } } - return parameterCount; + return parameterIndices.AsReadOnly(); } /// @@ -657,5 +680,12 @@ namespace System.Data.Jet /// The created object object ICloneable.Clone() => new JetCommand(this); + + protected class ParameterPlaceholder + { + public int Index { get; set; } + public string Name { get; set; } + public DbParameter Parameter { get; set; } + } } } \ No newline at end of file