Hyperbolic Graph Conv #23

Draft: wants to merge 12 commits into main
511 changes: 511 additions & 0 deletions HGCN - Hyperlib.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion README.md
@@ -28,7 +28,7 @@ Creating a hyperbolic neural network using Keras:
```python
import tensorflow as tf
from tensorflow import keras
from hyperlib.nn.layers.linear import LinearHyperbolic
from hyperlib.nn.optimizers.rsgd import RSGD
from hyperlib.manifold.poincare import Poincare
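# The remaining README lines are collapsed in this diff; they continue roughly
# as below (a sketch with assumed layer sizes, not the verbatim README):
model = tf.keras.models.Sequential([
    LinearHyperbolic(32, Poincare(), 1),
    LinearHyperbolic(32, Poincare(), 1),
    LinearHyperbolic(10, Poincare(), 1),
])
# Riemannian SGD keeps parameter updates on the manifold
model.compile(optimizer=RSGD(learning_rate=0.1), loss="sparse_categorical_crossentropy")
```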

40 changes: 28 additions & 12 deletions hyperlib/manifold/lorentz.py
@@ -7,7 +7,7 @@
class Lorentz(Manifold):
"""
Implementation of the Lorentz/Hyperboloid manifold defined by
:math: `L = \{ x \in R^{d+1} | -x_0^2 + x_1^2 + ... + x_d^2 = -K \}`,
where c = 1 / K is the hyperbolic curvature and d is the manifold dimension.

The point :math: `( \sqrt{K}, 0, \dots, 0 )` is referred to as "zero".
@@ -35,7 +35,7 @@ def minkowski_norm(self, u, keepdim=True):
def dist_squared(self, x, y, c):
"""Squared hyperbolic distance between x, y"""
K = 1. / c
theta = tf.clip_by_value(-self.minkowski_dot(x, y) / K,
clip_value_min=1.0 + self.eps[x.dtype], clip_value_max=self.max_norm)
return K * arcosh(theta)**2

@@ -44,9 +44,9 @@ def proj(self, x, c):
K = 1. / c
d1 = x.shape[-1]
y = x[:,1:d1]
y_sqnorm = tf.math.square(
tf.norm(y, ord=2, axis=1, keepdims=True))
t = tf.clip_by_value(K + y_sqnorm,
clip_value_min=self.eps[x.dtype],
clip_value_max=self.max_norm
)
@@ -70,7 +70,7 @@ def proj_tan0(self, u, c):
return tf.concat([z, ud], axis=1)

def expmap(self, u, x, c):
"""Maps vector u in the tangent space at x onto the manifold"""
"""Maps vector u in the tangent space at x onto the manifold"""
K = 1. / c
sqrtK = K ** 0.5
normu = self.minkowski_norm(u)
@@ -83,8 +83,8 @@ def expmap(self, u, x, c):
def logmap(self, y, x, c):
"""Maps point y in the manifold to the tangent space at x"""
K = 1. / c
xy = tf.clip_by_value(self.minkowski_dot(x, y) + K,
    clip_value_min=-self.max_norm, clip_value_max=-self.eps[x.dtype])
xy -= K
u = y + xy * x * c
normu = self.minkowski_norm(u)
@@ -99,7 +99,8 @@ def hyp_act(self, act, x, c_in, c_out):
return self.proj(self.expmap0(xt, c=c_out), c=c_out)

def expmap0(self, u, c):
"""Maps vector u in the tangent space at zero onto the manifold"""
"""Maps vector u in the tangent space at zero onto the manifold"""
print('x shape empmap0', u.shape)
K = 1. / c
sqrtK = K ** 0.5
d = u.shape[-1]
@@ -125,10 +126,16 @@ def logmap0(self, x, c):
y = tf.reshape(x[:,1:], [-1, d-1])
y_norm = tf.norm(y, ord=2, axis=1, keepdims=True)
y_norm = self.clip_norm(y_norm)
theta = tf.clip_by_value(
    x[:, 0:1] / sqrtK,
    clip_value_min=1.0 + self.eps[x.dtype],
    clip_value_max=self.max_norm
)

res = sqrtK * arcosh(theta) * y / y_norm
zeros = tf.zeros((b, 1), dtype=res.dtype)
return tf.concat([zeros, res], axis=1)

def mobius_add(self, x, y, c):
@@ -138,7 +145,7 @@ def mobius_add(self, x, y, c):

def mobius_matvec(self, m, x, c):
u = self.logmap0(x, c)
mu = u @ m
return self.expmap0(mu, c)

def ptransp(self, x, y, u, c):
@@ -178,3 +185,12 @@ def to_poincare(self, x, c):

def clip_norm(self, x):
return tf.clip_by_value(x, clip_value_min=self.min_norm, clip_value_max=self.max_norm)

def sqdist(self, x, y, c):
    """Squared hyperbolic distance between x, y, clamped for the Fermi-Dirac decoder"""
    K = 1. / c
    prod = self.minkowski_dot(x, y)
    theta = tf.clip_by_value(-prod / K,
        clip_value_min=1.0 + self.eps[x.dtype], clip_value_max=self.max_norm)
    sqdist = K * arcosh(theta) ** 2
    # clamp distance to avoid nans in the Fermi-Dirac decoder
    return tf.clip_by_value(sqdist, clip_value_min=0.0, clip_value_max=50.0)
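A quick round trip through the Lorentz API touched above, as a sketch: the float64 dtype and the default `eps`/`max_norm` clipping constants inherited from the base `Manifold` class are assumptions.

```python
import tensorflow as tf
from hyperlib.manifold.lorentz import Lorentz

m, c = Lorentz(), 1.0
u = tf.random.normal((4, 8), dtype=tf.float64)
u_tan = m.proj_tan0(u, c)            # tangent vector at "zero" (first coordinate 0)
x = m.proj(m.expmap0(u_tan, c), c)   # lift onto the hyperboloid
u_back = m.logmap0(x, c)             # should recover u_tan up to clipping
print(m.sqdist(x, x, c))             # ~0: squared distance from a point to itself
```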
16 changes: 12 additions & 4 deletions hyperlib/manifold/poincare.py
@@ -136,7 +136,7 @@ def proj(self, x, c):
projected = x / norm * maxnorm
return tf.where(cond, projected, x)

def mobius_add(self, x, y, c, axis=-1):
"""Element-wise Mobius addition.
Args:
x: Tensor of size B x dimension representing hyperbolic points.
@@ -146,9 +146,9 @@ def mobius_add(self, x, y, c):
Tensor of shape B x dimension representing the element-wise Mobius addition
of x and y.
"""
cx2 = c * tf.reduce_sum(x * x, axis=axis, keepdims=True)
cy2 = c * tf.reduce_sum(y * y, axis=axis, keepdims=True)
cxy = c * tf.reduce_sum(x * y, axis=axis, keepdims=True)
num = (1 + 2 * cxy + cy2) * x + (1 - cx2) * y
denom = 1 + 2 * cxy + cx2 * cy2
return self.proj(num / tf.maximum(denom, self.min_norm), c)
@@ -174,3 +174,11 @@ def single_query_attn_scores(self, key, query, c):
scores = (1. / denom) * scores
return scores

def sqdist(self, p1, p2, c):
    """Squared geodesic distance between p1, p2 on the Poincare ball"""
    sqrt_c = c ** 0.5
    # norm of the Mobius difference first, then atanh, per the standard formula
    dist_c = tf.math.atanh(
        sqrt_c * tf.norm(self.mobius_add(-p1, p2, c, axis=-1), ord=2, axis=-1)
    )
    dist = dist_c * 2 / sqrt_c
    return dist ** 2
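A minimal check of the corrected `sqdist`, as a sketch (the curvature and sample points are arbitrary, and `Poincare()` is assumed to construct with defaults):

```python
import tensorflow as tf
from hyperlib.manifold.poincare import Poincare

p, c = Poincare(), 1.0
x = tf.constant([[0.1, 0.2]], dtype=tf.float64)
y = tf.constant([[-0.3, 0.05]], dtype=tf.float64)
print(p.sqdist(x, x, c))  # ~0: mobius_add(-x, x) is the origin
print(p.sqdist(x, y, c))  # positive squared geodesic distance
```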
83 changes: 83 additions & 0 deletions hyperlib/models/__init__.py
@@ -0,0 +1,83 @@
import logging

import tensorflow as tf
from tensorflow import keras

from hyperlib.loss.constrastive_loss import contrastive_loss
from hyperlib.manifold.lorentz import Lorentz
from hyperlib.manifold.poincare import Poincare
from hyperlib.nn.layers.graph import HGCLayer


log = logging.getLogger(__name__)


class HGCN(tf.keras.Model):
"""
Hierarchical embeddings model from "Poincaré Embeddings for
Learning Hierarchical Representations" by Nickel and Kiela.
See hyperlib/examples/wordnet_embedding.py for a usage example.
"""

def __init__(self, vocab, embedding_dim=2, manifold=Poincare, c=1.0, clip_value=0.9):
super().__init__()

initializer=keras.initializers.RandomUniform(minval=-0.001, maxval=0.001, seed=None)
self.string_lookup = keras.layers.StringLookup(vocabulary=vocab, name="string_lookup")
self.embedding = keras.layers.Embedding(
len(vocab)+1,
embedding_dim,
embeddings_initializer=initializer,
name="embeddings",
)
self.vocab = vocab
self.manifold = manifold()
self.c = c
self.clip_value = clip_value

def call(self, inputs):
indices = self.string_lookup(inputs)
return self.embedding(indices)

def get_embeddings(self):
embeddings = self.embedding(tf.constant([i for i in range(len(self.vocab))]))
embeddings_copy = tf.identity(embeddings)
embeddings_hyperbolic = self.manifold.expmap0(embeddings_copy, c=self.c)
return embeddings_hyperbolic

def get_vocabulary(self):
return self.vocab

@staticmethod
def get_model(vocab, embedding_dim=2):
initializer=keras.initializers.RandomUniform(minval=-0.001, maxval=0.001, seed=None)
string_lookup_layer = keras.layers.StringLookup(vocabulary=vocab)

emb_layer = keras.layers.Embedding(
len(vocab)+1,
embedding_dim,
embeddings_initializer=initializer,
name="embeddings",
)

model = keras.Sequential([string_lookup_layer, emb_layer])
return model

def fit(self, train_dataset, optimizer, epochs=100):

for epoch in range(epochs):
log.info("Epoch %d" % (epoch,))
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
with tf.GradientTape() as tape:
pos_embs = self.embedding(self.string_lookup(x_batch_train))
neg_embs = self.embedding(self.string_lookup(y_batch_train))
loss_value = contrastive_loss(
pos_embs, neg_embs, self.manifold, c=self.c, clip_value=self.clip_value)

grads = tape.gradient(loss_value, self.embedding.trainable_weights)
optimizer.apply_gradients(zip(grads, self.embedding.trainable_weights))

if step % 100 == 0:
log.info("Training loss (for one batch) at step %d: %.4f"
% (step, float(loss_value)))
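For context, the embedding model above can be driven along these lines (a sketch: the toy vocabulary, dataset shape, and RSGD learning rate are assumptions, not part of this diff):

```python
import tensorflow as tf
from hyperlib.models import HGCN
from hyperlib.nn.optimizers.rsgd import RSGD

vocab = ["mammal", "carnivore", "dog", "cat"]
model = HGCN(vocab, embedding_dim=2)

# toy batches of (positive, negative) string pairs for the contrastive loss
dataset = tf.data.Dataset.from_tensor_slices((
    tf.constant([["dog", "mammal"], ["cat", "mammal"]]),
    tf.constant([["dog", "cat"], ["cat", "dog"]]),
)).batch(2)

model.fit(dataset, optimizer=RSGD(learning_rate=0.1), epochs=10)
embeddings = model.get_embeddings()  # hyperbolic embeddings via expmap0
```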
99 changes: 99 additions & 0 deletions hyperlib/nn/layers/graph.py
@@ -0,0 +1,99 @@
import tensorflow as tf
from tensorflow import keras

from .linear import LinearHyperbolic, ActivationHyperbolic
from hyperlib.manifold.lorentz import Lorentz
from hyperlib.manifold.poincare import Poincare


class HyperbolicAggregation(keras.layers.Layer):

def __init__(self, manifold, c):
super().__init__()
self.manifold = manifold
self.c = c

def call(self, inputs):
x_tangent, adj = inputs
support_t = tf.sparse.sparse_dense_matmul(adj, x_tangent)
#support_t = tf.linalg.matmul(adj, x_tangent)
output = self.manifold.proj(self.manifold.expmap0(support_t, c=self.c), c=self.c)
return output


class HGCLayer(keras.layers.Layer):
def __init__(self, manifold, input_size, c, activation):
super().__init__()

self.manifold = manifold
self.c = tf.Variable([c], trainable=False)
self.linear_layer = LinearHyperbolic(input_size, self.manifold, self.c, activation=None)
self.aggregation_layer = HyperbolicAggregation(self.manifold, self.c)
self.activation_layer = ActivationHyperbolic(self.manifold, self.c, self.c, activation)

def call(self, inputs):
# Step 1 (hyperbolic feature transform)
x, adj = inputs
# x = self.manifold.logmap0(x, c=self.c)

x = self.linear_layer(x)

# Step 2 (neighborhood aggregation)
x = self.aggregation_layer((x, adj))

# Step 3 (non-linear activation with different curvatures)
x = self.activation_layer(x)

return x
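A minimal forward pass through the layer, as a sketch: the feature width, curvature, and identity adjacency are placeholders, and dtypes may need aligning with the float64 cast inside `LinearHyperbolic`.

```python
import tensorflow as tf
from hyperlib.manifold.lorentz import Lorentz
from hyperlib.nn.layers.graph import HGCLayer

n, d = 5, 8
x = tf.random.normal((n, d), dtype=tf.float64)           # node features
adj = tf.sparse.from_dense(tf.eye(n, dtype=tf.float64))  # toy adjacency (self-loops only)
layer = HGCLayer(Lorentz(), input_size=d, c=1.0, activation="relu")
out = layer((x, adj))  # hyperbolic node representations, shape (n, d)
```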


class HGCNLP(keras.Model):

def __init__(self, input_size, dropout=0.4):
super().__init__()

self.input_size = input_size

self.manifold = Lorentz()
self.c_map = tf.Variable([0.4], trainable=False)
self.c0 = tf.Variable([0.4], trainable=False)
self.c1 = tf.Variable([0.4], trainable=False)
self.c2 = tf.Variable([0.4], trainable=False)

self.conv0 = HGCLayer(self.manifold, self.input_size, self.c0, activation="relu")
self.conv1 = HGCLayer(self.manifold, self.input_size, self.c1, activation="relu")
self.conv2 = HGCLayer(self.manifold, self.input_size, self.c2, activation="relu")

def call(self, inputs):
    x, adj = inputs
    # Map Euclidean features to hyperbolic space
    x_tan = self.manifold.proj_tan0(x, self.c1)
    x_hyp = self.manifold.expmap0(x_tan, c=self.c1)
    x_hyp = self.manifold.proj(x_hyp, c=self.c1)
    # Stack multiple hyperbolic graph convolution layers
    x = self.conv0((x_hyp, adj))
    #x = self.conv1((x, adj))
    #x = self.conv2((x, adj))

    # TODO: add link prediction/node classification heads as described below.
    #
    # Note 1: hyperbolic embeddings at the last layer can be used to predict
    #   node attributes or links.
    # Note 2: for link prediction, use the Fermi-Dirac decoder, a generalization
    #   of the sigmoid, to compute probability scores for edges, then train the
    #   HGCN by minimizing cross-entropy loss with negative sampling.
    # Note 3: for node classification, map the output of the last HGCN layer to
    #   the tangent space at the origin with the logarithmic map, then run
    #   Euclidean multinomial logistic regression. Alternatively, classify points
    #   directly on the hyperboloid with the hyperbolic multinomial logistic loss,
    #   which performs similarly to Euclidean classification. A link prediction
    #   regularization objective can also be added in node classification tasks,
    #   to encourage last-layer embeddings to preserve the graph structure.
    return x

def decode(self, emb_in, emb_out):
    # assumes curvature self.c and Fermi-Dirac parameters self.r, self.t are
    # set on the model; they are not defined in this draft yet
    sqdist = self.manifold.sqdist(emb_in, emb_out, self.c)
    # Fermi-Dirac decoder to compute edge probabilities
    probs = 1. / (tf.exp((sqdist - self.r) / self.t) + 1.0)
    return probs
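The Fermi-Dirac decoder used in `decode` scores an edge as `p = 1 / (exp((d^2 - r) / t) + 1)`: close pairs score near 1, distant pairs decay toward 0. A self-contained sketch with hypothetical hyperparameters `r` and `t`:

```python
import tensorflow as tf

def fermi_dirac(sqdist, r=2.0, t=1.0):
    """Edge probability from squared hyperbolic distance (r, t are hypothetical defaults)."""
    return 1.0 / (tf.exp((sqdist - r) / t) + 1.0)

print(fermi_dirac(tf.constant([0.0, 2.0, 10.0])))  # ~[0.88, 0.50, 0.0003]
```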
34 changes: 31 additions & 3 deletions hyperlib/nn/layers/lin_hyp.py → hyperlib/nn/layers/linear.py
@@ -7,13 +7,14 @@ class LinearHyperbolic(keras.layers.Layer):
Implementation of a hyperbolic linear layer for a neural network, that inherits from the keras Layer class
"""

def __init__(self, units, manifold, c, use_activation=False, activation=None, use_bias=False):
super().__init__()
self.units = units
self.c = tf.Variable([c], dtype="float64")
self.manifold = manifold
self.activation = keras.activations.get(activation)
self.use_bias = use_bias
self.use_activation = use_activation

def build(self, batch_input_shape):
w_init = tf.random_normal_initializer()
@@ -40,14 +41,16 @@ def call(self, inputs):
inputs = tf.cast(inputs, tf.float64)
mv = self.manifold.mobius_matvec(self.kernel, inputs, self.c)
res = self.manifold.proj(mv, c=self.c)

if self.use_bias:
hyp_bias = self.manifold.expmap0(self.bias, c=self.c)
hyp_bias = self.manifold.proj(hyp_bias, c=self.c)
res = self.manifold.mobius_add(res, hyp_bias, c=self.c)
res = self.manifold.proj(res, c=self.c)

if self.use_activation:
    res = self.activation(res)
return res

def get_config(self):
base_config = super().get_config()
@@ -58,3 +61,28 @@ def get_config(self):
"manifold": self.manifold,
"curvature": self.c
}

class ActivationHyperbolic(keras.layers.Layer):
def __init__(self, manifold, c_in, c_out, activation):
super().__init__()
self.activation = keras.activations.get(activation)
self.c_in = c_in
self.c_out = c_out
self.manifold = manifold

def build(self, input_shape):
self.built = True

def call(self, inputs):
inputs_tan = self.activation(self.manifold.logmap0(inputs, c=self.c_in))
inputs_tan = self.manifold.proj_tan0(inputs_tan, c=self.c_out)
out = self.manifold.expmap0(inputs_tan, c=self.c_out)
return self.manifold.proj(out, c=self.c_out)

def get_config(self):
return {
"activation": keras.activations.serialize(self.activation),
"c_in": self.c_in,
"c_out": self.c_out,
"manifold": self.manifold.name,
}
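
End to end, the two layers in this file compose the way `HGCLayer` uses them. A sketch with arbitrary widths and matching input/output curvatures; `Poincare.expmap0` is assumed to mirror the Lorentz signature:

```python
import tensorflow as tf
from hyperlib.manifold.poincare import Poincare
from hyperlib.nn.layers.linear import LinearHyperbolic, ActivationHyperbolic

manifold, c = Poincare(), 1.0
linear = LinearHyperbolic(16, manifold, c)
act = ActivationHyperbolic(manifold, c_in=c, c_out=c, activation="relu")

x = tf.random.normal((4, 8), dtype=tf.float64)
h = manifold.expmap0(x, c=c)  # map Euclidean features onto the ball
y = act(linear(h))            # hyperbolic linear map, then hyperbolic activation
```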