From b047b68e80a9d0948446402dda83d8639444632e Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 16 Jul 2021 12:57:39 -0700 Subject: [PATCH 1/3] Start resource scopes Signed-off-by: Ryan Nett --- .../org/tensorflow/resource/Deallocator.java | 23 + .../org/tensorflow/resource/Resource.java | 22 + .../tensorflow/resource/ResourceManager.java | 37 + .../tensorflow/resource/ResourceScope.java | 127 +++ .../resource/WeakIdentityHashMap.java | 980 ++++++++++++++++++ 5 files changed, 1189 insertions(+) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/Deallocator.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/Resource.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceManager.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceScope.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/WeakIdentityHashMap.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/Deallocator.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/Deallocator.java new file mode 100644 index 00000000000..622f1851065 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/Deallocator.java @@ -0,0 +1,23 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + + */ +package org.tensorflow.resource; + +@FunctionalInterface +public interface Deallocator { + public void deallocate(); +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/Resource.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/Resource.java new file mode 100644 index 00000000000..f7d1f62dcd7 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/Resource.java @@ -0,0 +1,22 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + + */ +package org.tensorflow.resource; + +public interface Resource { + public Deallocator deallocator(); +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceManager.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceManager.java new file mode 100644 index 00000000000..6c76f77606f --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceManager.java @@ -0,0 +1,37 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + + */ +package org.tensorflow.resource; + +import java.util.Collections; +import java.util.Set; + +public class ResourceManager { + private static Set hasGc = Collections.synchronizedSet(Collections.newSetFromMap(new WeakIdentityHashMap<>())); + + public static void cleanup(Resource resource){ + if(hasGc.add(resource)){ + Deallocator d = resource.deallocator(); + if(d != null){ + addGc(resource, d); + } + } + } + public static void addGc(Resource resource, Deallocator deallocator){ + + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceScope.java new file mode 100644 index 00000000000..4f3a4d1c6f6 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceScope.java @@ -0,0 +1,127 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + + */ +package org.tensorflow.resource; + +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.Map; +import java.util.Set; + +public class ResourceScope implements Resource, AutoCloseable{ + private static final Map references = new WeakIdentityHashMap<>(); + + protected static void addReference(Resource resource, boolean doGC){ + synchronized (references) { + references.put(resource, references.getOrDefault(resource, 0) + 1); + } + if (doGC) { + ResourceManager.cleanup(resource); + } + } + + protected static void removeReference(Resource resource){ + boolean doDeallocate = false; + synchronized (references) { + if (references.containsKey(resource)) { + int refs = references.get(resource) - 1; + if (refs <= 0) { + references.remove(resource); + doDeallocate = true; + } else { + references.put(resource, refs); + } + } else { + doDeallocate = true; + } + } + // separate to ensure no deadlock from weird deallocators + if(doDeallocate){ + resource.deallocator().deallocate(); + } + } + + private void ensureOpen(){ + synchronized (resources) { + if (isClosed[0]) { + throw new IllegalStateException("Resource scope has been closed"); + } + } + } + + public void add(Resource resource) { + synchronized (resources) { + ensureOpen(); + resources.add(resource); + addReference(resource, isWeak); + } + } + + public void remove(Resource resource) { + synchronized (resources) { + ensureOpen(); + if (resources.remove(resource)) { + removeReference(resource); + } + } + } + + @Override + public void close() throws Exception { + synchronized (resources) { + resources.forEach(ResourceScope::removeReference); + resources.clear(); + isClosed[0] = true; + } + } + + @Override + public Deallocator deallocator() { + return () -> { + synchronized (resources) { + resources.forEach(ResourceScope::removeReference); + resources.clear(); + isClosed[0] = true; + } + }; + } + + ResourceScope(boolean weak){ + isWeak = weak; + if(isWeak){ + resources = Collections.newSetFromMap(new WeakIdentityHashMap<>()); + } else { + resources = Collections.newSetFromMap(new IdentityHashMap<>()); + } + } + + public static ResourceScope strongScope(){ + return new ResourceScope(false); + } + + public static ResourceScope weakScope(){ + return new ResourceScope(true); + } + + private final Set resources; + private final boolean isWeak; + private final boolean[] isClosed = new boolean[]{ false }; + + { + ResourceManager.addGc(this, deallocator()); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/WeakIdentityHashMap.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/WeakIdentityHashMap.java new file mode 100644 index 00000000000..3ac5cf385c4 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/WeakIdentityHashMap.java @@ -0,0 +1,980 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + + */ +package org.tensorflow.resource; + +import java.lang.ref.ReferenceQueue; +import java.lang.ref.WeakReference; +import java.util.AbstractCollection; +import java.util.AbstractMap; +import java.util.AbstractSet; +import java.util.ArrayList; +import java.util.Collection; +import java.util.ConcurrentModificationException; +import java.util.Iterator; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NoSuchElementException; +import java.util.Set; + +public class WeakIdentityHashMap extends AbstractMap implements Map { + + /** The default initial capacity -- MUST be a power of two. */ + private static final int DEFAULT_INITIAL_CAPACITY = 16; + + /** + * The maximum capacity, used if a higher value is implicitly specified by either of the + * constructors with arguments. MUST be a power of two ≤ {@code 1<<30}. + */ + private static final int MAXIMUM_CAPACITY = 1 << 30; + + /** The load fast used when none specified in constructor. */ + private static final float DEFAULT_LOAD_FACTOR = 0.75f; + + /** The table, resized as necessary. Length MUST Always be a power of two. */ + private Entry[] table; + + /** The number of key-value mappings contained in this weak hash map. */ + private int size; + + /** The next size value at which to resize (capacity * load factor). */ + private int threshold; + + /** The load factor for the hash table. */ + private final float loadFactor; + + /** Reference queue for cleared WeakEntries. */ + private final ReferenceQueue queue = new ReferenceQueue<>(); + + /** + * The number of times this HashMap has been structurally modified Structural modifications are + * those that change the number of mappings in the HashMap or otherwise modify its internal + * structure (e.g., rehash). This field is used to make iterators on Collection-views of the + * HashMap fail-fast. (See ConcurrentModificationException). + */ + private volatile int modCount; + + /** + * Constructs a new, empty WeakIdentityHashMap with the given initial capacity and + * the given load factor. + * + * @param initialCapacity the initial capacity of the WeakIdentityHashMap + * @param loadFactor the load factor of the WeakIdentityHashMap + * @throws IllegalArgumentException If the initial capacity is negative, or if the load factor is + * nonpositive + */ + public WeakIdentityHashMap(int initialCapacity, float loadFactor) { + if (initialCapacity < 0) { + throw new IllegalArgumentException("Illegal Initial Capacity: " + initialCapacity); + } + if (initialCapacity > MAXIMUM_CAPACITY) { + initialCapacity = MAXIMUM_CAPACITY; + } + + if (loadFactor <= 0 || Float.isNaN(loadFactor)) { + throw new IllegalArgumentException("Illegal Load factor: " + loadFactor); + } + int capacity = 1; + while (capacity < initialCapacity) { + capacity <<= 1; + } + @SuppressWarnings("unchecked") + Entry[] tmpTable = (Entry[]) new Entry[capacity]; + table = tmpTable; + this.loadFactor = loadFactor; + threshold = (int) (capacity * loadFactor); + } + + /** + * Constructs a new, empty WeakIdentityHashMap with the given initial capacity and + * the default load factor, which is 0.75. + * + * @param initialCapacity the initial capacity of the WeakIdentityHashMap + * @throws IllegalArgumentException If the initial capacity is negative + */ + public WeakIdentityHashMap(int initialCapacity) { + this(initialCapacity, DEFAULT_LOAD_FACTOR); + } + + /** + * Constructs a new, empty WeakIdentityHashMap with the default initial capacity (16) + * and the default load factor (0.75). + */ + public WeakIdentityHashMap() { + this.loadFactor = DEFAULT_LOAD_FACTOR; + threshold = DEFAULT_INITIAL_CAPACITY; + @SuppressWarnings("unchecked") + Entry[] tmpTable = (Entry[]) new Entry[DEFAULT_INITIAL_CAPACITY]; + table = tmpTable; + } + + /** + * Constructs a new WeakIdentityHashMap with the same mappings as the specified + * Map. The WeakIdentityHashMap is created with default load factor, + * which is 0.75 and an initial capacity sufficient to hold the mappings in the + * specified Map. + * + * @param t the map whose mappings are to be placed in this map + * @throws NullPointerException if the specified map is null + * @since 1.3 + */ + public WeakIdentityHashMap(Map t) { + this(Math.max((int) (t.size() / DEFAULT_LOAD_FACTOR) + 1, 16), DEFAULT_LOAD_FACTOR); + putAll(t); + } + + // internal utilities + + /** Value representing null keys inside tables. */ + // This is problematic because it isn't of the right type. + // We can't lie here to the type system by claiming it is of type K, + // because NULL_KEY is a static field but K is a per-instance type parameter. + private static final Object NULL_KEY = new Object(); + + /** + * Use NULL_KEY for key if it is null. + * + * @param key a key, or null + * @return key if it is null, otherwise {@link NULL_KEY} + */ + // not: "private static K maskNull(K key)" because NULL_KEY isn't of type K. + private static Object maskNull(Object key) { + return (key == null ? NULL_KEY : key); + } + + /** Return internal representation of null key back to caller as null. */ + // Argument is actually either of type K, or is NULL_KEY. + @SuppressWarnings("unchecked") + private static K unmaskNull(K key) { + return (key == NULL_KEY ? null : key); + } + + /** Check for equality of non-null reference x and possibly-null y. Uses identity equality. */ + static boolean eq(Object x, Object y) { + return x == y; + } + + /** Return the hash code for x. */ + static int hasher(Object x) { + return System.identityHashCode(x); + } + + /** Return index for hash code h. */ + static int indexFor(int h, int length) { + return h & (length - 1); + } + + /** Expunge stale entries from the table. */ + @SuppressWarnings("allcheckers:purity") // actually has side effects due to weak pointers + private void expungeStaleEntries() { + Entry e; + // These types look wrong to me. + while ((e = (Entry) queue.poll()) != null) { // unchecked cast + int h = e.hash; + int i = indexFor(h, table.length); + + Entry prev = table[i]; + Entry p = prev; + while (p != null) { + Entry next = p.next; + if (p == e) { + if (prev == e) { + table[i] = next; + } else { + prev.next = next; + } + e.next = null; // Help GC + e.value = null; // " " + size--; + break; + } + prev = p; + p = next; + } + } + } + + /** Return the table after first expunging stale entries. */ + private Entry[] getTable() { + expungeStaleEntries(); + return table; + } + + /** + * Returns the number of key-value mappings in this map. This result is a snapshot, and may not + * reflect unprocessed entries that will be removed before next attempted access because they are + * no longer referenced. + */ + @Override + public int size() { + if (size == 0) { + return 0; + } + expungeStaleEntries(); + return size; + } + + /** + * Returns true if this map contains no key-value mappings. This result is a + * snapshot, and may not reflect unprocessed entries that will be removed before next attempted + * access because they are no longer referenced. + */ + @Override + public boolean isEmpty() { + return size() == 0; + } + + /** + * Returns the value to which the specified key is mapped in this weak hash map, or null + * if the map contains no mapping for this key. A return value of null does + * not necessarily indicate that the map contains no mapping for the key; it is also + * possible that the map explicitly maps the key to null. The containsKey + * method may be used to distinguish these two cases. + * + * @param key the key whose associated value is to be returned + * @return the value to which this map maps the specified key, or null if the map + * contains no mapping for this key. + * @see #put(Object, Object) + */ + @Override + public V get(Object key) { + Object k = maskNull(key); + int h = hasher(k); + Entry[] tab = getTable(); + int index = indexFor(h, tab.length); + Entry e = tab[index]; + while (e != null) { + if (e.hash == h && eq(k, e.get())) { + return e.value; + } + e = e.next; + } + return null; + } + + /** + * Returns true if this map contains a mapping for the specified key. + * + * @param key the key whose presence in this map is to be tested + * @return true if there is a mapping for key; false + * otherwise + */ + @Override + public boolean containsKey(Object key) { + return getEntry(key) != null; + } + + /** + * Returns the entry associated with the specified key in the HashMap. Returns null if the HashMap + * contains no mapping for this key. + */ + Entry getEntry(Object key) { + Object k = maskNull(key); + int h = hasher(k); + Entry[] tab = getTable(); + int index = indexFor(h, tab.length); + Entry e = tab[index]; + while (e != null && !(e.hash == h && eq(k, e.get()))) { + e = e.next; + } + return e; + } + + /** + * Associates the specified value with the specified key in this map. If the map previously + * contained a mapping for this key, the old value is replaced. + * + * @param key key with which the specified value is to be associated + * @param value value to be associated with the specified key + * @return previous value associated with specified key, or null if there was no + * mapping for key. A null return can also indicate that the HashMap previously + * associated null with the specified key. + */ + @SuppressWarnings("NonAtomicVolatileUpdate") + @Override + public V put(K key, V value) { + @SuppressWarnings("unchecked") + K k = (K) maskNull(key); + int h = System.identityHashCode(k); + Entry[] tab = getTable(); + int i = indexFor(h, tab.length); + + for (Entry e = tab[i]; e != null; e = e.next) { + if (h == e.hash && eq(k, e.get())) { + V oldValue = e.value; + if (value != oldValue) { + e.value = value; + } + return oldValue; + } + } + + modCount++; + Entry e = tab[i]; + tab[i] = new Entry(k, value, queue, h, e); + if (++size >= threshold) { + resize(tab.length * 2); + } + return null; + } + + /** + * Rehashes the contents of this map into a new array with a larger capacity. This method is + * called automatically when the number of keys in this map reaches its threshold. + * + *

If current capacity is MAXIMUM_CAPACITY, this method does not resize the map, but sets + * threshold to Integer.MAX_VALUE. This has the effect of preventing future calls. + * + * @param newCapacity the new capacity, MUST be a power of two; must be greater than current + * capacity unless current capacity is MAXIMUM_CAPACITY (in which case value is irrelevant) + */ + void resize(int newCapacity) { + Entry[] oldTable = getTable(); + int oldCapacity = oldTable.length; + if (oldCapacity == MAXIMUM_CAPACITY) { + threshold = Integer.MAX_VALUE; + return; + } + + @SuppressWarnings("unchecked") + Entry[] newTable = (Entry[]) new Entry[newCapacity]; + transfer(oldTable, newTable); + table = newTable; + + /* + * If ignoring null elements and processing ref queue caused massive + * shrinkage, then restore old table. This should be rare, but avoids + * unbounded expansion of garbage-filled tables. + */ + if (size >= threshold / 2) { + threshold = (int) (newCapacity * loadFactor); + } else { + expungeStaleEntries(); + transfer(newTable, oldTable); + table = oldTable; + } + } + + /** Transfer all entries from src to dest tables. */ + private void transfer(Entry[] src, Entry[] dest) { + for (int j = 0; j < src.length; ++j) { + Entry e = src[j]; + src[j] = null; // Help GC (?) + while (e != null) { + Entry next = e.next; + Object key = e.get(); + if (key == null) { + e.next = null; // Help GC + e.value = null; // " " + size--; + } else { + int i = indexFor(e.hash, dest.length); + e.next = dest[i]; + dest[i] = e; + } + e = next; + } + } + } + + /** + * Copies all of the mappings from the specified map to this map These mappings will replace any + * mappings that this map had for any of the keys currently in the specified map. + * + *

+ * + * @param m mappings to be stored in this map + * @throws NullPointerException if the specified map is null + */ + @Override + public void putAll(Map m) { + int numKeysToBeAdded = m.size(); + if (numKeysToBeAdded == 0) { + return; + } + + /* + * Expand the map if the map if the number of mappings to be added + * is greater than or equal to threshold. This is conservative; the + * obvious condition is (m.size() + size) >= threshold, but this + * condition could result in a map with twice the appropriate capacity, + * if the keys to be added overlap with the keys already in this map. + * By using the conservative calculation, we subject ourself + * to at most one extra resize. + */ + if (numKeysToBeAdded > threshold) { + int targetCapacity = (int) (numKeysToBeAdded / loadFactor + 1); + if (targetCapacity > MAXIMUM_CAPACITY) { + targetCapacity = MAXIMUM_CAPACITY; + } + int newCapacity = table.length; + while (newCapacity < targetCapacity) { + newCapacity <<= 1; + } + if (newCapacity > table.length) { + resize(newCapacity); + } + } + + for (Iterator> i = m.entrySet().iterator(); + i.hasNext(); ) { + Map.Entry e = i.next(); + put(e.getKey(), e.getValue()); + } + } + + /** + * Removes the mapping for this key from this map if present. + * + * @param key key whose mapping is to be removed from the map + * @return previous value associated with specified key, or null if there was no + * mapping for key. A null return can also indicate that the map previously + * associated null with the specified key. + */ + @SuppressWarnings("NonAtomicVolatileUpdate") + @Override + public V remove(Object key) { + Object k = maskNull(key); + int h = hasher(k); + Entry[] tab = getTable(); + int i = indexFor(h, tab.length); + Entry prev = tab[i]; + Entry e = prev; + + while (e != null) { + Entry next = e.next; + if (h == e.hash && eq(k, e.get())) { + modCount++; + size--; + if (prev == e) { + tab[i] = next; + } else { + prev.next = next; + } + return e.value; + } + prev = e; + e = next; + } + + return null; + } + + /** Special version of remove needed by Entry set. */ + @SuppressWarnings("NonAtomicVolatileUpdate") + Entry removeMapping(Object o) { + if (!(o instanceof Map.Entry)) { + return null; + } + Entry[] tab = getTable(); + Map.Entry entry = (Map.Entry) o; + Object k = maskNull(entry.getKey()); + int h = hasher(k); + int i = indexFor(h, tab.length); + Entry prev = tab[i]; + Entry e = prev; + + while (e != null) { + Entry next = e.next; + if (h == e.hash && e.equals(entry)) { + modCount++; + size--; + if (prev == e) { + tab[i] = next; + } else { + prev.next = next; + } + return e; + } + prev = e; + e = next; + } + + return null; + } + + /** Removes all mappings from this map. */ + @SuppressWarnings("NonAtomicVolatileUpdate") + @Override + public void clear() { + // clear out ref queue. We don't need to expunge entries + // since table is getting cleared. + while (queue.poll() != null) { + ; + } + + modCount++; + Entry[] tab = table; + for (int i = 0; i < tab.length; ++i) { + tab[i] = null; // Help GC (?) + } + size = 0; + + // Allocation of array may have caused GC, which may have caused + // additional entries to go stale. Removing these entries from the + // reference queue will make them eligible for reclamation. + while (queue.poll() != null) { + ; + } + } + + /** + * Returns true if this map maps one or more keys to the specified value. + * + * @param value value whose presence in this map is to be tested + * @return true if this map maps one or more keys to the specified value. + */ + @Override + public boolean containsValue(Object value) { + if (value == null) { + return containsNullValue(); + } + + Entry[] tab = getTable(); + for (int i = tab.length; i-- > 0; ) { + for (Entry e = tab[i]; e != null; e = e.next) { + if (value.equals(e.value)) { + return true; + } + } + } + return false; + } + + /** Special-case code for containsValue with null argument. */ + private boolean containsNullValue() { + Entry[] tab = getTable(); + for (int i = tab.length; i-- > 0; ) { + for (Entry e = tab[i]; e != null; e = e.next) { + if (e.value == null) { + return true; + } + } + } + return false; + } + + /** The entries in this hash table extend WeakReference, using its main ref field as the key. */ + private static class Entry extends WeakReference implements Map.Entry { + private V value; + private final int hash; + private Entry next; + + /** Create new entry. */ + Entry(K key, V value, ReferenceQueue queue, int hash, Entry next) { + super(key, queue); + this.value = value; + this.hash = hash; + this.next = next; + } + + @Override + public K getKey() { + return WeakIdentityHashMap.unmaskNull(get()); + } + + @Override + public V getValue() { + return value; + } + + @Override + public V setValue(V newValue) { + V oldValue = value; + value = newValue; + return oldValue; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Map.Entry)) { + return false; + } + Map.Entry e = (Map.Entry) o; + Object k1 = getKey(); + Object k2 = e.getKey(); + if (eq(k1, k2)) { + Object v1 = getValue(); + Object v2 = e.getValue(); + if (v1 == v2 || (v1 != null && v1.equals(v2))) { + return true; + } + } + return false; + } + + @Override + public int hashCode() { + Object k = getKey(); + Object v = getValue(); + return ((k == null ? 0 : hasher(k)) ^ (v == null ? 0 : v.hashCode())); + } + + @Override + public String toString() { + return getKey() + "=" + getValue(); + } + } + + private abstract class HashIterator implements Iterator { + int index; + WeakIdentityHashMap.Entry entry = null; + WeakIdentityHashMap.Entry lastReturned = null; + int expectedModCount = modCount; + + /** Strong reference needed to avoid disappearance of key between hasNext and next. */ + Object nextKey = null; + + /** + * Strong reference needed to avoid disappearance of key between nextEntry() and any use of the + * entry + */ + Object currentKey = null; + + HashIterator() { + index = (size() != 0 ? table.length : 0); + } + + @Override + public boolean hasNext() { + WeakIdentityHashMap.Entry[] t = table; + + while (nextKey == null) { + WeakIdentityHashMap.Entry e = entry; + int i = index; + while (e == null && i > 0) { + e = t[--i]; + } + entry = e; + index = i; + if (e == null) { + currentKey = null; + return false; + } + nextKey = e.get(); // hold on to key in strong ref + if (nextKey == null) { + entry = entry.next; + } + } + return true; + } + + /** The common parts of next() across different types of iterators. */ + protected WeakIdentityHashMap.Entry nextEntry() { + if (modCount != expectedModCount) { + throw new ConcurrentModificationException(); + } + if (nextKey == null && !hasNext()) { + throw new NoSuchElementException(); + } + + lastReturned = entry; + entry = entry.next; + currentKey = nextKey; + nextKey = null; + return lastReturned; + } + + @Override + public void remove() { + if (lastReturned == null) { + throw new IllegalStateException(); + } + if (modCount != expectedModCount) { + throw new ConcurrentModificationException(); + } + + WeakIdentityHashMap.this.remove(currentKey); + expectedModCount = modCount; + lastReturned = null; + currentKey = null; + } + } + + private class ValueIterator extends HashIterator { + @Override + public V next() { + return nextEntry().value; + } + } + + private class KeyIterator extends HashIterator { + @Override + public K next() { + return nextEntry().getKey(); + } + } + + private class EntryIterator extends HashIterator> { + @Override + public Map.Entry next() { + return nextEntry(); + } + } + + // Views + + private transient Set> entrySet = null; + private transient volatile Set our_keySet = null; + + /** + * Returns a set view of the keys contained in this map. The set is backed by the map, so changes + * to the map are reflected in the set, and vice-versa. The set supports element removal, which + * removes the corresponding mapping from this map, via the Iterator.remove, + * Set.remove, removeAll, retainAll, and clear + * operations. It does not support the add or addAll operations. + * + * @return a set view of the keys contained in this map + */ + @Override + public Set keySet() { + Set ks = our_keySet; + return (ks != null ? ks : (our_keySet = new KeySet())); + } + + private class KeySet extends AbstractSet { + @Override + public Iterator iterator() { + return new KeyIterator(); + } + + @Override + public int size() { + return WeakIdentityHashMap.this.size(); + } + + @Override + public boolean contains(Object o) { + return containsKey(o); + } + + @Override + public boolean remove(Object o) { + if (containsKey(o)) { + WeakIdentityHashMap.this.remove(o); + return true; + } else { + return false; + } + } + + @Override + public void clear() { + WeakIdentityHashMap.this.clear(); + } + + @Override + public Object[] toArray() { + Collection c = new ArrayList(size()); + for (Iterator i = iterator(); i.hasNext(); ) { + c.add(i.next()); + } + return c.toArray(); + } + + @Override + public T[] toArray(T[] a) { + Collection c = new ArrayList(size()); + for (Iterator i = iterator(); i.hasNext(); ) { + c.add(i.next()); + } + return c.toArray(a); + } + } + + transient volatile Collection our_values = null; + + /** + * Returns a collection view of the values contained in this map. The collection is backed by the + * map, so changes to the map are reflected in the collection, and vice-versa. The collection + * supports element removal, which removes the corresponding mapping from this map, via the + * Iterator.remove, Collection.remove, removeAll, retainAll + * , and clear operations. It does not support the add or + * addAll operations. + * + * @return a collection view of the values contained in this map + */ + @Override + public Collection values() { + Collection vs = our_values; + return (vs != null ? vs : (our_values = new Values())); + } + + private class Values extends AbstractCollection { + @Override + public Iterator iterator() { + return new ValueIterator(); + } + + @Override + public int size() { + return WeakIdentityHashMap.this.size(); + } + + @Override + public boolean contains(Object o) { + return containsValue(o); + } + + @Override + public void clear() { + WeakIdentityHashMap.this.clear(); + } + + @Override + public Object[] toArray() { + Collection c = new ArrayList(size()); + for (Iterator i = iterator(); i.hasNext(); ) { + c.add(i.next()); + } + return c.toArray(); + } + + @Override + public T[] toArray(T[] a) { + Collection c = new ArrayList(size()); + for (Iterator i = iterator(); i.hasNext(); ) { + c.add(i.next()); + } + return c.toArray(a); + } + } + + /** + * Returns a collection view of the mappings contained in this map. Each element in the returned + * collection is a Map.Entry. The collection is backed by the map, so changes to the + * map are reflected in the collection, and vice-versa. The collection supports element removal, + * which removes the corresponding mapping from the map, via the Iterator.remove, + * Collection.remove, removeAll, retainAll, and clear + * operations. It does not support the add or addAll operations. + * + * @return a collection view of the mappings contained in this map + * @see java.util.Map.Entry + */ + @Override + public Set> entrySet() { + Set> es = entrySet; + return (es != null ? es : (entrySet = new EntrySet())); + } + + private class EntrySet extends AbstractSet> { + @Override + public Iterator> iterator() { + return new EntryIterator(); + } + + @Override + public boolean contains(Object o) { + if (!(o instanceof Map.Entry)) { + return false; + } + Map.Entry e = (Map.Entry) o; + Object k = e.getKey(); + WeakIdentityHashMap.Entry candidate = getEntry(k); + return candidate != null && candidate.equals(e); + } + + @Override + public boolean remove(Object o) { + return removeMapping(o) != null; + } + + @Override + public int size() { + return WeakIdentityHashMap.this.size(); + } + + @Override + public void clear() { + WeakIdentityHashMap.this.clear(); + } + + @Override + public Object[] toArray() { + Collection> c = new ArrayList>(size()); + for (Iterator> i = iterator(); i.hasNext(); ) { + c.add(new OurSimpleEntry(i.next())); + } + return c.toArray(); + } + + @Override + public T[] toArray(T[] a) { + Collection> c = new ArrayList>(size()); + for (Iterator> i = iterator(); i.hasNext(); ) { + c.add(new OurSimpleEntry(i.next())); + } + return c.toArray(a); + } + } + + /** Version copied from Abstract Map because it is not public. */ + static class OurSimpleEntry implements Map.Entry { + K key; + V value; + + public OurSimpleEntry(K key, V value) { + this.key = key; + this.value = value; + } + + public OurSimpleEntry(Map.Entry e) { + this.key = e.getKey(); + this.value = e.getValue(); + } + + @Override + public K getKey() { + return key; + } + + @Override + public V getValue() { + return value; + } + + @Override + public V setValue(V value) { + V oldValue = this.value; + this.value = value; + return oldValue; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Map.Entry)) { + return false; + } + Map.Entry e = (Map.Entry) o; + return WeakIdentityHashMap.eq(key, e.getKey()) && eq(value, e.getValue()); + } + + @Override + public int hashCode() { + return ((key == null) ? 0 : key.hashCode()) ^ ((value == null) ? 0 : value.hashCode()); + } + + @Override + public String toString() { + return key + "=" + value; + } + + private static boolean eq(Object o1, Object o2) { + return (o1 == null ? o2 == null : o1.equals(o2)); + } + } +} \ No newline at end of file From 03c55f817ecf46e39ea3207736b8b30e33064c42 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 16 Jul 2021 13:36:11 -0700 Subject: [PATCH 2/3] dependencies, implicit/explicit split Signed-off-by: Ryan Nett --- .../tensorflow/resource/ResourceManager.java | 2 +- .../tensorflow/resource/ResourceScope.java | 151 ++++++++++++------ 2 files changed, 107 insertions(+), 46 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceManager.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceManager.java index 6c76f77606f..3840fbefc1b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceManager.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceManager.java @@ -31,7 +31,7 @@ public static void cleanup(Resource resource){ } } } - public static void addGc(Resource resource, Deallocator deallocator){ + public static void addGc(Object resource, Deallocator deallocator){ } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceScope.java index 4f3a4d1c6f6..6b74bb9c6e1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceScope.java @@ -1,20 +1,20 @@ /* - Copyright 2021 The TensorFlow Authors. All Rights Reserved. + Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= - */ +*/ package org.tensorflow.resource; import java.util.Collections; @@ -22,10 +22,10 @@ import java.util.Map; import java.util.Set; -public class ResourceScope implements Resource, AutoCloseable{ +public class ResourceScope implements AutoCloseable { private static final Map references = new WeakIdentityHashMap<>(); - protected static void addReference(Resource resource, boolean doGC){ + private static void addReference(Resource resource, boolean doGC) { synchronized (references) { references.put(resource, references.getOrDefault(resource, 0) + 1); } @@ -34,7 +34,7 @@ protected static void addReference(Resource resource, boolean doGC){ } } - protected static void removeReference(Resource resource){ + private static void removeReference(Resource resource) { boolean doDeallocate = false; synchronized (references) { if (references.containsKey(resource)) { @@ -49,79 +49,140 @@ protected static void removeReference(Resource resource){ doDeallocate = true; } } - // separate to ensure no deadlock from weird deallocators - if(doDeallocate){ + // separate to ensure no deadlock from deallocators that reference other resources + if (doDeallocate) { resource.deallocator().deallocate(); } } - private void ensureOpen(){ + /** + * Adds a resource to this scope. The resource will be kept alive at least until this scope is closed, + * or it is unreachable if this is a weak scope. + */ + public void add(Resource resource) { synchronized (resources) { - if (isClosed[0]) { - throw new IllegalStateException("Resource scope has been closed"); + ensureOpen(); + if (resources.add(resource)) { + addReference(resource, isWeak); } } } - public void add(Resource resource) { + /** + * Remove a resource from this scope. If this was the only scope referencing it, it will be deallocated. + */ + public void remove(Resource resource) { synchronized (resources) { ensureOpen(); - resources.add(resource); - addReference(resource, isWeak); + if (resources.remove(resource)) { + removeReference(resource); + } } } - public void remove(Resource resource) { + /** Ensures {@code other} is not closed before {@code this}. */ + public void dependsOn(ResourceScope other) { synchronized (resources) { ensureOpen(); - if (resources.remove(resource)) { - removeReference(resource); + if (dependencies.add(other)) { + other.addConsumer(); } } } @Override - public void close() throws Exception { + public void close() { synchronized (resources) { - resources.forEach(ResourceScope::removeReference); - resources.clear(); - isClosed[0] = true; + closeHelper(resources, consumers, dependencies, isClosed); } } - @Override public Deallocator deallocator() { return () -> { synchronized (resources) { - resources.forEach(ResourceScope::removeReference); - resources.clear(); - isClosed[0] = true; + closeHelper(resources, consumers, dependencies, isClosed); } }; } - ResourceScope(boolean weak){ + public static ResourceScope strongScope(boolean implicit) { + return new ResourceScope(false, implicit); + } + + public static ResourceScope weakScope(boolean implicit) { + return new ResourceScope(true, implicit); + } + + public static ResourceScope strongImplicitScope() { + return strongScope(true); + } + + public static ResourceScope strongExplicitScope() { + return strongScope(false); + } + + public static ResourceScope weakImplicitScope() { + return weakScope(true); + } + + public static ResourceScope weakExplicitScope() { + return weakScope(false); + } + + /** + * @param weak will resources be GCd + * @param implicit will the scope itself be closed on GC + */ + ResourceScope(boolean weak, boolean implicit) { isWeak = weak; - if(isWeak){ + isImplicit = implicit; + if (isWeak) { resources = Collections.newSetFromMap(new WeakIdentityHashMap<>()); } else { resources = Collections.newSetFromMap(new IdentityHashMap<>()); } + if (isImplicit) { + ResourceManager.addGc(this, deallocator()); + } } - public static ResourceScope strongScope(){ - return new ResourceScope(false); + private synchronized void addConsumer(){ + consumers++; } - public static ResourceScope weakScope(){ - return new ResourceScope(true); + private synchronized void removeConsumer(){ + consumers--; } - private final Set resources; - private final boolean isWeak; - private final boolean[] isClosed = new boolean[]{ false }; + private void ensureOpen() { + synchronized (resources) { + if (isClosed[0]) { + throw new IllegalStateException("Resource scope has been closed"); + } + } + } - { - ResourceManager.addGc(this, deallocator()); + private static void closeHelper( + Set resources, int consumers, Set dependencies, boolean[] isClosed) { + if (consumers > 0) { + throw new IllegalStateException( + "There are still " + + consumers + + " open scopes with " + + "dependencies on this scope, can not close."); + } + dependencies.forEach(ResourceScope::removeConsumer); + dependencies.clear(); + resources.forEach(ResourceScope::removeReference); + resources.clear(); + isClosed[0] = true; } + + private final Set resources; + private final boolean isWeak; + private final boolean isImplicit; + private final boolean[] isClosed = new boolean[] {false}; + private int consumers = 0; + private final Set dependencies = + Collections.newSetFromMap(new IdentityHashMap<>()); } From e9b497f0e6afdcb78db778d666148d52127b0a4b Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 16 Jul 2021 13:43:07 -0700 Subject: [PATCH 3/3] Better synchronization Signed-off-by: Ryan Nett --- .../tensorflow/resource/ResourceScope.java | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceScope.java index 6b74bb9c6e1..8bca7540b36 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/resource/ResourceScope.java @@ -21,6 +21,7 @@ import java.util.IdentityHashMap; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; public class ResourceScope implements AutoCloseable { private static final Map references = new WeakIdentityHashMap<>(); @@ -56,8 +57,8 @@ private static void removeReference(Resource resource) { } /** - * Adds a resource to this scope. The resource will be kept alive at least until this scope is closed, - * or it is unreachable if this is a weak scope. + * Adds a resource to this scope. The resource will be kept alive at least until this scope is + * closed, or it is unreachable if this is a weak scope. */ public void add(Resource resource) { synchronized (resources) { @@ -69,7 +70,8 @@ public void add(Resource resource) { } /** - * Remove a resource from this scope. If this was the only scope referencing it, it will be deallocated. + * Remove a resource from this scope. If this was the only scope referencing it, it will be + * deallocated. */ public void remove(Resource resource) { synchronized (resources) { @@ -146,25 +148,23 @@ public static ResourceScope weakExplicitScope() { } } - private synchronized void addConsumer(){ - consumers++; + private void addConsumer() { + consumers.incrementAndGet(); } - private synchronized void removeConsumer(){ - consumers--; + private void removeConsumer() { + consumers.decrementAndGet(); } private void ensureOpen() { - synchronized (resources) { - if (isClosed[0]) { - throw new IllegalStateException("Resource scope has been closed"); - } + if (isClosed[0]) { + throw new IllegalStateException("Resource scope has been closed"); } } private static void closeHelper( - Set resources, int consumers, Set dependencies, boolean[] isClosed) { - if (consumers > 0) { + Set resources, AtomicInteger consumers, Set dependencies, boolean[] isClosed) { + if (consumers.get() > 0) { throw new IllegalStateException( "There are still " + consumers @@ -182,7 +182,7 @@ private static void closeHelper( private final boolean isWeak; private final boolean isImplicit; private final boolean[] isClosed = new boolean[] {false}; - private int consumers = 0; + private final AtomicInteger consumers = new AtomicInteger(0); private final Set dependencies = Collections.newSetFromMap(new IdentityHashMap<>()); }