diff --git a/src/EFCore.Jet/Query/ExpressionTranslators/Internal/JetObjectToStringTranslator.cs b/src/EFCore.Jet/Query/ExpressionTranslators/Internal/JetObjectToStringTranslator.cs index 5797be3..2e67bff 100644 --- a/src/EFCore.Jet/Query/ExpressionTranslators/Internal/JetObjectToStringTranslator.cs +++ b/src/EFCore.Jet/Query/ExpressionTranslators/Internal/JetObjectToStringTranslator.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; using System.Reflection; +using EntityFrameworkCore.Jet.Utilities; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Query; @@ -19,40 +20,76 @@ namespace EntityFrameworkCore.Jet.Query.ExpressionTranslators.Internal { private readonly JetSqlExpressionFactory _sqlExpressionFactory; - private static readonly Type[] _typeMapping = - { - typeof(int), - typeof(long), - typeof(DateTime), - typeof(Guid), - typeof(bool), - typeof(byte), - typeof(byte[]), - typeof(double), - typeof(DateTimeOffset), - typeof(char), - typeof(short), - typeof(float), - typeof(decimal), - typeof(TimeSpan), - typeof(uint), - typeof(ushort), - typeof(ulong), - typeof(sbyte), - }; + private const int DefaultLength = 100; + + private static readonly Dictionary _typeMapping + = new() + { + { typeof(sbyte), "varchar(4)" }, + { typeof(byte), "varchar(3)" }, + { typeof(short), "varchar(6)" }, + { typeof(ushort), "varchar(5)" }, + { typeof(int), "varchar(11)" }, + { typeof(uint), "varchar(10)" }, + { typeof(long), "varchar(20)" }, + { typeof(ulong), "varchar(20)" }, + { typeof(float), $"varchar({DefaultLength})" }, + { typeof(double), $"varchar({DefaultLength})" }, + { typeof(decimal), $"varchar({DefaultLength})" }, + { typeof(char), "varchar(1)" }, + { typeof(DateTime), $"varchar({DefaultLength})" }, + { typeof(DateTimeOffset), $"varchar({DefaultLength})" }, + { typeof(TimeSpan), $"varchar({DefaultLength})" }, + { typeof(Guid), "varchar(36)" }, + { typeof(byte[]), $"varchar({DefaultLength})" }, + }; public JetObjectToStringTranslator(SqlExpressionFactory sqlExpressionFactory) => _sqlExpressionFactory = (JetSqlExpressionFactory)sqlExpressionFactory; - public SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadOnlyList arguments, IDiagnosticsLogger logger) + public virtual SqlExpression Translate( + SqlExpression instance, + MethodInfo method, + IReadOnlyList arguments, + IDiagnosticsLogger logger) { - return method.Name == nameof(ToString) - && arguments.Count == 0 - && instance != null - && _typeMapping.Contains( - instance.Type - .UnwrapNullableType() - .UnwrapEnumType()) + Check.NotNull(method, nameof(method)); + Check.NotNull(arguments, nameof(arguments)); + Check.NotNull(logger, nameof(logger)); + + if (instance == null || method.Name != nameof(ToString) || arguments.Count != 0) + { + return null; + } + + if (instance.Type == typeof(bool)) + { + if (instance is ColumnExpression columnExpression && columnExpression.IsNullable) + { + return _sqlExpressionFactory.Case( + new[] + { + new CaseWhenClause( + _sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(false)), + _sqlExpressionFactory.Constant(false.ToString())), + new CaseWhenClause( + _sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(true)), + _sqlExpressionFactory.Constant(true.ToString())) + }, + _sqlExpressionFactory.Constant(null)); + } + + return _sqlExpressionFactory.Case( + new[] + { + new CaseWhenClause( + _sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(false)), + _sqlExpressionFactory.Constant(false.ToString())) + }, + _sqlExpressionFactory.Constant(true.ToString())); + } + + return _typeMapping.TryGetValue(instance.Type, out var storeType) ? _sqlExpressionFactory.Convert(instance, typeof(string)) : null; }