diff --git a/src/SQLite.Net/Orm.cs b/src/SQLite.Net/Orm.cs index 996274a17..19678b3eb 100755 --- a/src/SQLite.Net/Orm.cs +++ b/src/SQLite.Net/Orm.cs @@ -37,11 +37,11 @@ internal static class Orm public const string ImplicitIndexSuffix = "Id"; internal static string SqlDecl(TableMapping.Column p, bool storeDateTimeAsTicks, IBlobSerializer serializer, - IDictionary extraTypeMappings) + IDictionary extraTypeMappings, bool hasCompositePK = false) { var decl = "\"" + p.Name + "\" " + SqlType(p, storeDateTimeAsTicks, serializer, extraTypeMappings) + " "; - if (p.IsPK) + if (p.IsPK && !hasCompositePK) { decl += "primary key "; } diff --git a/src/SQLite.Net/SQLiteConnection.cs b/src/SQLite.Net/SQLiteConnection.cs index 4b7753e43..cd7fa9d6c 100644 --- a/src/SQLite.Net/SQLiteConnection.cs +++ b/src/SQLite.Net/SQLiteConnection.cs @@ -371,7 +371,7 @@ public int CreateTable(Type ty, CreateFlags createFlags = CreateFlags.None) { var map = GetMapping(ty, createFlags); - var query = "create table if not exists \"" + map.TableName + "\"(\n"; + var query = new StringBuilder("create table if not exists \"").Append(map.TableName).Append("\"( \n"); var mapColumns = map.Columns; @@ -380,12 +380,24 @@ public int CreateTable(Type ty, CreateFlags createFlags = CreateFlags.None) throw new Exception("Table has no (public) columns"); } - var decls = mapColumns.Select(p => Orm.SqlDecl(p, StoreDateTimeAsTicks, Serializer, ExtraTypeMappings)); - var decl = string.Join(",\n", decls.ToArray()); - query += decl; - query += ")"; + if (map.HasCompositePK) + { + var compositePK = mapColumns.Where(c => c.IsPK).ToList(); + + var decls = mapColumns.Select(p => Orm.SqlDecl(p, StoreDateTimeAsTicks, Serializer, ExtraTypeMappings, map.HasCompositePK)); + var decl = string.Join(",\n", decls.ToArray()); + query.Append(decl).Append(",\n"); + query.Append("primary key (").Append(string.Join(",", compositePK.Select(pk => pk.Name))).Append(")"); + query.Append(")"); + } + else + { + var decls = mapColumns.Select(p => Orm.SqlDecl(p, StoreDateTimeAsTicks, Serializer, ExtraTypeMappings)); + var decl = string.Join(",\n", decls.ToArray()); + query.Append(decl).Append(")"); + } - var count = Execute(query); + var count = Execute(query.ToString()); if (count == 0) { @@ -558,8 +570,13 @@ private void MigrateTable(TableMapping map) foreach (var p in toBeAdded) { + if (p.IsPK) + { + throw new NotSupportedException("The new column may not have a PRIMARY KEY constraint."); + } + var addCol = "alter table \"" + map.TableName + "\" add column " + - Orm.SqlDecl(p, StoreDateTimeAsTicks, Serializer, ExtraTypeMappings); + Orm.SqlDecl(p, StoreDateTimeAsTicks, Serializer, ExtraTypeMappings, map.HasCompositePK); Execute(addCol); } } @@ -798,7 +815,7 @@ public TableQuery Table() where T : class /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). /// /// - /// The primary key. + /// The primary key. Needs to be Dictionary if table has composite PK. /// /// /// The object with the given primary key. Throws a not found exception @@ -808,7 +825,24 @@ public TableQuery Table() where T : class public T Get(object pk) where T : class { var map = GetMapping(typeof (T)); - return Query(map.GetByPrimaryKeySql, pk).First(); + if (map.HasCompositePK) + { + IDictionary compositePK = pk as Dictionary; + if (compositePK == null) + { + throw new NotSupportedException(map.TableName + " table has a composite primary key. Make sure primary key is passed in as Dictionary."); + } + var cpk = map.CompositePK; + if (compositePK.Keys.Intersect(cpk.Select(p => p.Name)).Count() < cpk.Length) + { + throw new NotSupportedException("Cannot get from " + map.TableName + ": CompositePK mismatch. Make sure PK names are valid."); + } + return Query(map.GetByPrimaryKeySql, compositePK.Values.ToArray()).First(); + } + else + { + return Query(map.GetByPrimaryKeySql, pk).First(); + } } /// @@ -834,7 +868,7 @@ public T Get(Expression> predicate) where T : class /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). /// /// - /// The primary key. + /// The primary key. Needs to be Dictionary if table has composite PK. /// /// /// The object with the given primary key or null @@ -843,8 +877,25 @@ public T Get(Expression> predicate) where T : class [PublicAPI] public T Find(object pk) where T : class { - var map = GetMapping(typeof (T)); - return Query(map.GetByPrimaryKeySql, pk).FirstOrDefault(); + var map = GetMapping(typeof(T)); + if (map.HasCompositePK) + { + IDictionary compositePK = pk as Dictionary; + if (compositePK == null) + { + throw new NotSupportedException(map.TableName + " table has a composite primary key. Make sure primary key is passed in as Dictionary."); + } + var cpk = map.CompositePK; + if (compositePK.Keys.Intersect(cpk.Select(p => p.Name)).Count() < cpk.Length) + { + throw new NotSupportedException("Cannot find in " + map.TableName + ": CompositePK mismatch. Make sure PK names are valid."); + } + return Query(map.GetByPrimaryKeySql, compositePK.Values.ToArray()).FirstOrDefault(); + } + else + { + return Query(map.GetByPrimaryKeySql, pk).FirstOrDefault(); + } } /// @@ -873,7 +924,7 @@ public T FindWithQuery(string query, params object[] args) where T : class /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). /// /// - /// The primary key. + /// The primary key. Needs to be Dictionary if table has composite PK. /// /// /// The TableMapping used to identify the object type. @@ -885,7 +936,24 @@ public T FindWithQuery(string query, params object[] args) where T : class [PublicAPI] public object Find(object pk, TableMapping map) { - return Query(map, map.GetByPrimaryKeySql, pk).FirstOrDefault(); + if (map.HasCompositePK) + { + IDictionary compositePK = pk as Dictionary; + if (compositePK == null) + { + throw new NotSupportedException(map.TableName + " table has a composite primary key. Make sure primary key is passed in as Dictionary."); + } + var cpk = map.CompositePK; + if (compositePK.Keys.Intersect(cpk.Select(p => p.Name)).Count() < cpk.Length) + { + throw new NotSupportedException("Cannot find in " + map.TableName + ": CompositePK mismatch. Make sure PK names are valid."); + } + return Query(map, map.GetByPrimaryKeySql, compositePK.Values.ToArray()).FirstOrDefault(); + } + else + { + return Query(map, map.GetByPrimaryKeySql, pk).FirstOrDefault(); + } } /// @@ -1443,10 +1511,23 @@ public int Insert(object obj, string extra, Type objType) } var map = GetMapping(objType); + TableMapping.Column pk = null; - if (map.PK != null && map.PK.IsAutoGuid) + if (map.HasCompositePK) { - var prop = objType.GetRuntimeProperty(map.PK.PropertyName); + pk = map.CompositePK.FirstOrDefault(p => p.IsAutoGuid); + } + else + { + if (map.PK != null && map.PK.IsAutoGuid) + { + pk = map.PK; + } + } + + if (pk != null) + { + var prop = objType.GetRuntimeProperty(pk.PropertyName); if (prop != null) { if (prop.GetValue(obj, null).Equals(Guid.Empty)) @@ -1551,29 +1632,57 @@ public int Update(object obj, Type objType) } var map = GetMapping(objType); + string q = null; + object[] ps = null; + + if (map.HasCompositePK) + { + var compositePK = map.CompositePK; + var cols = from p in map.Columns + where !compositePK.Any(pk => pk == p) + select p; + + var pslist = (from c in cols + select c.GetValue(obj)).ToList(); + + pslist.AddRange(compositePK.Select(pk => pk.GetValue(obj))); - var pk = map.PK; + q = string.Format("update \"{0}\" set {1} where {2}", map.TableName, + string.Join(",", (from c in cols + select "\"" + c.Name + "\" = ? ").ToArray()), string.Join(" and ", compositePK.Select(pk => "\"" + pk.Name + "\" = ? "))); - if (pk == null) + ps = pslist.ToArray(); + } + else { - throw new NotSupportedException("Cannot update " + map.TableName + ": it has no PK"); + var pk = map.PK; + + if (pk == null) + { + throw new NotSupportedException("Cannot update " + map.TableName + ": it has no PK"); + } + + var cols = from p in map.Columns + where p != pk + select p; + + var vals = from c in cols + select c.GetValue(obj); + var pslist = new List(vals) + { + pk.GetValue(obj) + }; + + q = string.Format("update \"{0}\" set {1} where {2} = ? ", map.TableName, + string.Join(",", (from c in cols + select "\"" + c.Name + "\" = ? ").ToArray()), pk.Name); + + ps = pslist.ToArray(); } - var cols = from p in map.Columns - where p != pk - select p; - var vals = from c in cols - select c.GetValue(obj); - var ps = new List(vals) - { - pk.GetValue(obj) - }; - var q = string.Format("update \"{0}\" set {1} where {2} = ? ", map.TableName, - string.Join(",", (from c in cols - select "\"" + c.Name + "\" = ? ").ToArray()), pk.Name); try { - rowsAffected = Execute(q, ps.ToArray()); + rowsAffected = Execute(q, ps); } catch (SQLiteException ex) { @@ -1635,20 +1744,35 @@ public int UpdateAll(IEnumerable objects, bool runInTransaction = true) public int Delete(object objectToDelete) { var map = GetMapping(objectToDelete.GetType()); - var pk = map.PK; - if (pk == null) + string q = null; + object[] ps = null; + + if (map.HasCompositePK) + { + var compositePK = map.CompositePK; + q = string.Format("delete from \"{0}\" where {1}", map.TableName, string.Join(" and ", compositePK.Select(pk => "\"" + pk.Name + "\" = ? "))); + ps = (from pk in compositePK + select pk.GetValue(objectToDelete)).ToArray(); + } + else { - throw new NotSupportedException("Cannot delete " + map.TableName + ": it has no PK"); + var pk = map.PK; + if (pk == null) + { + throw new NotSupportedException("Cannot delete " + map.TableName + ": it has no PK"); + } + q = string.Format("delete from \"{0}\" where \"{1}\" = ?", map.TableName, pk.Name); + ps = new object[] { pk.GetValue(objectToDelete) }; } - var q = string.Format("delete from \"{0}\" where \"{1}\" = ?", map.TableName, pk.Name); - return Execute(q, pk.GetValue(objectToDelete)); + + return Execute(q, ps); } /// /// Deletes the object with the specified primary key. /// /// - /// The primary key of the object to delete. + /// The primary key of the object to delete. Needs to be Dictionary if table has composite PK. /// /// /// The number of objects deleted. @@ -1660,13 +1784,34 @@ public int Delete(object objectToDelete) public int Delete(object primaryKey) { var map = GetMapping(typeof (T)); - var pk = map.PK; - if (pk == null) + + if (map.HasCompositePK) { - throw new NotSupportedException("Cannot delete " + map.TableName + ": it has no PK"); + var cpk = map.CompositePK; + IDictionary compositePK = primaryKey as Dictionary; + if (compositePK == null) + { + throw new NotSupportedException(map.TableName + " table has a composite primary key. Make sure primary key is passed in as Dictionary."); + } + if (compositePK.Keys.Intersect(cpk.Select(p => p.Name)).Count() < cpk.Length) + { + throw new NotSupportedException("Cannot delete " + map.TableName + ": CompositePK mismatch. Make sure PK names are valid."); + } + var q = string.Format("delete from \"{0}\" where {1}", map.TableName, string.Join(" and ", compositePK.Keys.Select(pk => "\"" + pk + "\" = ? "))); + var ps = (from pk in compositePK.Values + select pk).ToArray(); + return Execute(q, ps); + } + else + { + var pk = map.PK; + if (pk == null) + { + throw new NotSupportedException("Cannot delete " + map.TableName + ": it has no PK"); + } + var q = string.Format("delete from \"{0}\" where \"{1}\" = ?", map.TableName, pk.Name); + return Execute(q, primaryKey); } - var q = string.Format("delete from \"{0}\" where \"{1}\" = ?", map.TableName, pk.Name); - return Execute(q, primaryKey); } /// diff --git a/src/SQLite.Net/TableMapping.cs b/src/SQLite.Net/TableMapping.cs index c14a5081b..ba70b8e6e 100644 --- a/src/SQLite.Net/TableMapping.cs +++ b/src/SQLite.Net/TableMapping.cs @@ -34,6 +34,7 @@ public class TableMapping { private readonly Column _autoPk; private Column[] _insertColumns; + private Column _pk; [PublicAPI] public TableMapping(Type type, IEnumerable properties, CreateFlags createFlags = CreateFlags.None) @@ -57,23 +58,40 @@ public TableMapping(Type type, IEnumerable properties, CreateFlags } } Columns = cols.ToArray(); - foreach (var c in Columns) + CompositePK = Columns.Where(col => col.IsPK).ToArray(); + + if (CompositePK.Length > 1) { - if (c.IsAutoInc && c.IsPK) - { - _autoPk = c; - } - if (c.IsPK) + HasCompositePK = true; + } + + if (!HasCompositePK) + { + foreach (var c in Columns) { - PK = c; + if (c.IsAutoInc && c.IsPK) + { + _autoPk = c; + } + + if (c.IsPK) + { + _pk = c; + } } - } - HasAutoIncPK = _autoPk != null; + HasAutoIncPK = _autoPk != null; + } - if (PK != null) + if (CompositePK.Length > 1) + { + string compositePKString = string.Join(" and ", CompositePK.Select(pk => "\"" + pk.Name + "\" = ? ")); + GetByPrimaryKeySql = string.Format("select * from \"{0}\" where {1}", TableName, compositePKString); + } + else if (PK != null) { - GetByPrimaryKeySql = string.Format("select * from \"{0}\" where \"{1}\" = ?", TableName, PK.Name); + string pkString = PK.Name; + GetByPrimaryKeySql = string.Format("select * from \"{0}\" where \"{1}\" = ?", TableName, pkString); } else { @@ -92,7 +110,23 @@ public TableMapping(Type type, IEnumerable properties, CreateFlags public Column[] Columns { get; private set; } [PublicAPI] - public Column PK { get; private set; } + public Column PK + { + get + { + if (HasCompositePK) + { + throw new NotSupportedException("Table has a composite primary key. Use CompositePK property instead."); + } + else + { + return _pk; + } + } + } + + [PublicAPI] + public Column[] CompositePK { get; private set; } [PublicAPI] public string GetByPrimaryKeySql { get; private set; } @@ -100,6 +134,9 @@ public TableMapping(Type type, IEnumerable properties, CreateFlags [PublicAPI] public bool HasAutoIncPK { get; private set; } + [PublicAPI] + public bool HasCompositePK { get; private set; } + [PublicAPI] public Column[] InsertColumns { diff --git a/tests/CreateTableImplicitTest.cs b/tests/CreateTableImplicitTest.cs index 5ef716755..13079a3d3 100644 --- a/tests/CreateTableImplicitTest.cs +++ b/tests/CreateTableImplicitTest.cs @@ -51,6 +51,7 @@ public void ImplicitAutoInc() TableMapping mapping = db.GetMapping(); + Assert.IsFalse(mapping.HasCompositePK); Assert.IsNotNull(mapping.PK); Assert.AreEqual("Id", mapping.PK.Name); Assert.IsTrue(mapping.PK.IsPK); @@ -66,6 +67,7 @@ public void ImplicitAutoIncAsPassedInTypes() TableMapping mapping = db.GetMapping(); + Assert.IsFalse(mapping.HasCompositePK); Assert.IsNotNull(mapping.PK); Assert.AreEqual("Id", mapping.PK.Name); Assert.IsTrue(mapping.PK.IsPK); @@ -94,6 +96,7 @@ public void ImplicitPK() TableMapping mapping = db.GetMapping(); + Assert.IsFalse(mapping.HasCompositePK); Assert.IsNotNull(mapping.PK); Assert.AreEqual("Id", mapping.PK.Name); Assert.IsTrue(mapping.PK.IsPK); @@ -111,6 +114,7 @@ public void ImplicitPKAutoInc() TableMapping mapping = db.GetMapping(); + Assert.IsFalse(mapping.HasCompositePK); Assert.IsNotNull(mapping.PK); Assert.AreEqual("Id", mapping.PK.Name); Assert.IsTrue(mapping.PK.IsPK); @@ -126,6 +130,7 @@ public void ImplicitPKAutoIncAsPassedInTypes() TableMapping mapping = db.GetMapping(); + Assert.IsFalse(mapping.HasCompositePK); Assert.IsNotNull(mapping.PK); Assert.AreEqual("Id", mapping.PK.Name); Assert.IsTrue(mapping.PK.IsPK); @@ -141,6 +146,7 @@ public void ImplicitPkAsPassedInTypes() TableMapping mapping = db.GetMapping(); + Assert.IsFalse(mapping.HasCompositePK); Assert.IsNotNull(mapping.PK); Assert.AreEqual("Id", mapping.PK.Name); Assert.IsTrue(mapping.PK.IsPK); @@ -156,6 +162,7 @@ public void WithoutImplicitMapping() TableMapping mapping = db.GetMapping(); + Assert.IsFalse(mapping.HasCompositePK); Assert.IsNull(mapping.PK); TableMapping.Column column = mapping.Columns[2]; diff --git a/tests/DeleteTest.cs b/tests/DeleteTest.cs index d8bde89e9..6de95a320 100644 --- a/tests/DeleteTest.cs +++ b/tests/DeleteTest.cs @@ -1,3 +1,4 @@ +using System.Collections.Generic; using System.Linq; using NUnit.Framework; using SQLite.Net.Attributes; @@ -16,12 +17,27 @@ private class TestTable public string Test { get; set; } } + private class TestTableCompositeKey + { + [PrimaryKey] + public int Id { get; set; } + + [PrimaryKey] + public int TestIndex { get; set; } + + public int Datum { get; set; } + public string Test { get; set; } + } + private const int Count = 100; private SQLiteConnection CreateDb() { var db = new TestDb(); + db.CreateTable(); + db.CreateTable(); + var items = from i in Enumerable.Range(0, Count) select new TestTable @@ -31,7 +47,20 @@ from i in Enumerable.Range(0, Count) } ; db.InsertAll(items); - Assert.AreEqual(Count, db.Table().Count()); + + var itemsCompositeKey = + from i in Enumerable.Range(0, Count) + select new TestTableCompositeKey + { + Datum = 1000 + i, + Test = "Hello World", + Id = i, + TestIndex = i + 1 + } + ; + db.InsertAll(itemsCompositeKey); + Assert.AreEqual(Count, db.Table().Count()); + return db; } @@ -129,5 +158,35 @@ public void DeletePKOne() Assert.AreEqual(1, r); Assert.AreEqual(Count - 1, db.Table().Count()); } + + [Test] + public void DeletePKNoneComposite() + { + var db = CreateDb(); + + var compositePK = new Dictionary(); + compositePK.Add("Id", 348597); + compositePK.Add("TestIndex", 348598); + + var r = db.Delete(compositePK); + + Assert.AreEqual(0, r); + Assert.AreEqual(Count, db.Table().Count()); + } + + [Test] + public void DeletePKOneComposite() + { + var db = CreateDb(); + + var compositePK = new Dictionary(); + compositePK.Add("Id", 1); + compositePK.Add("TestIndex", 2); + + var r = db.Delete(compositePK); + + Assert.AreEqual(1, r); + Assert.AreEqual(Count - 1, db.Table().Count()); + } } } \ No newline at end of file