diff --git a/README.md b/README.md
index a3f6868..68b2875 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,8 @@
+# TF2-GrouPy
+
+A TensorFlow 2 port of GrouPy, best used alongside [tf2-keras-gcnn](https://github.com/neel-dey/tf2-keras-gcnn). Note: this repo also includes the [group-to-Z2 indexing from nom](https://github.com/nom/GrouPy/commit/7b1128b6cb0d33e5733667f8e07490ea1d44a7dc).
+
+The original README follows below, with a TF2-compatible minimal working example.
### Note: If you are looking for a PyTorch implementation, please have a look at the pull requests by Jorn Peters and Adam Bielski (https://github.com/tscohen/GrouPy/pulls).
@@ -41,34 +46,37 @@ $ nosetests -v
### TensorFlow
-```
+```python
import numpy as np
import tensorflow as tf
+
+tf.compat.v1.disable_eager_execution()
+
from groupy.gconv.tensorflow_gconv.splitgconv2d import gconv2d, gconv2d_util
# Construct graph
-x = tf.placeholder(tf.float32, [None, 9, 9, 3])
+x = tf.compat.v1.placeholder(tf.float32, [None, 9, 9, 3])
gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
h_input='Z2', h_output='D4', in_channels=3, out_channels=64, ksize=3)
-w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
+w = tf.Variable(tf.random.truncated_normal(w_shape, stddev=1.))
y = gconv2d(input=x, filter=w, strides=[1, 1, 1, 1], padding='SAME',
gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
h_input='D4', h_output='D4', in_channels=64, out_channels=64, ksize=3)
-w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
+w = tf.Variable(tf.random.truncated_normal(w_shape, stddev=1.))
y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
# Compute
-init = tf.global_variables_initializer()
-sess = tf.Session()
+init = tf.compat.v1.global_variables_initializer()
+sess = tf.compat.v1.Session()
sess.run(init)
y = sess.run(y, feed_dict={x: np.random.randn(10, 9, 9, 3)})
sess.close()
-print y.shape # (10, 9, 9, 512)
+print(y.shape) # (10, 9, 9, 512)
```
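+
+This port also includes the group-to-Z2 indexing mentioned at the top, which lets a layer map a group feature map back to a planar (Z2) one. A minimal sketch, assuming it is spliced in before the `# Compute` block above (so that `y` is still the D4 feature-map tensor); the layer sizes and the name `z` are illustrative:
+
+```python
+# Map the 64-G-channel D4 feature map back to a planar feature map.
+gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
+    h_input='D4', h_output='Z2', in_channels=64, out_channels=64, ksize=3)
+w = tf.Variable(tf.random.truncated_normal(w_shape, stddev=1.))
+z = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
+            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
+# z is expected to have plain 2D channels again: (batch, 9, 9, 64)
+```
+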
### Chainer
@@ -157,4 +165,4 @@ These subclasses can easily be tested against the group axioms and other mathema
## References
-1. T.S. Cohen, M. Welling, [Group Equivariant Convolutional Networks](http://www.jmlr.org/proceedings/papers/v48/cohenc16.pdf). Proceedings of the International Conference on Machine Learning (ICML), 2016.
\ No newline at end of file
+1. T.S. Cohen, M. Welling, [Group Equivariant Convolutional Networks](http://www.jmlr.org/proceedings/papers/v48/cohenc16.pdf). Proceedings of the International Conference on Machine Learning (ICML), 2016.
diff --git a/groupy/gconv/tensorflow_gconv/splitgconv2d.py b/groupy/gconv/tensorflow_gconv/splitgconv2d.py
index 644df85..3abbcdc 100644
--- a/groupy/gconv/tensorflow_gconv/splitgconv2d.py
+++ b/groupy/gconv/tensorflow_gconv/splitgconv2d.py
@@ -12,7 +12,6 @@ def gconv2d(input, filter, strides, padding, gconv_indices, gconv_shape_info,
Tensorflow implementation of the group convolution.
This function has the same interface as the standard convolution nn.conv2d, except for two new parameters,
gconv_indices and gconv_shape_info. These can be obtained from gconv2d_util(), and are described below
-
:param input: a tensor with (batch, height, width, in channels) axes.
:param filter: a tensor with (ksize, ksize, in channels * in transformations, out channels) axes.
The shape for filter can be obtained from gconv2d_util().
@@ -37,7 +36,7 @@ def gconv2d(input, filter, strides, padding, gconv_indices, gconv_shape_info,
transformed_filter = transform_filter_2d_nhwc(w=filter, flat_indices=gconv_indices, shape_info=gconv_shape_info)
# Convolve input with transformed filters
- conv = tf.nn.conv2d(input=input, filter=transformed_filter, strides=strides, padding=padding,
+ conv = tf.compat.v1.nn.conv2d(input=input, filter=transformed_filter, strides=strides, padding=padding,
use_cudnn_on_gpu=use_cudnn_on_gpu, data_format=data_format, name=name)
return conv
@@ -50,9 +49,8 @@ def gconv2d_util(h_input, h_output, in_channels, out_channels, ksize):
1) an array of indices used in the filter transformation step of gconv2d
2) shape information required by gconv2d
3) the shape of the filter tensor to be allocated and passed to gconv2d
-
:param h_input: one of ('Z2', 'C4', 'D4'). Use 'Z2' for the first layer. Use 'C4' or 'D4' for later layers.
- :param h_output: one of ('C4', 'D4'). What kind of transformations to use (rotations or roto-reflections).
+    :param h_output: one of ('Z2', 'C4', 'D4'). The output group: 'C4' for rotations, 'D4' for roto-reflections, or 'Z2' to map back to a planar feature map.
The choice of h_output of one layer should equal h_input of the next layer.
:param in_channels: the number of input channels. Note: this refers to the number of (3D) channels on the group.
The number of 2D channels will be 1, 4, or 8 times larger, depending the value of h_input.
@@ -78,10 +76,22 @@ def gconv2d_util(h_input, h_output, in_channels, out_channels, ksize):
gconv_indices = flatten_indices(make_d4_p4m_indices(ksize=ksize))
nti = 8
nto = 8
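+    # New in this port: map a group (C4 or D4) feature map back to a planar Z2 output (nto = 1).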
+ elif h_input == 'D4' and h_output == 'Z2':
+ gconv_indices = flatten_indices(make_d4_z2_indices(ksize=ksize))
+ nti = 8
+ nto = 1
+ elif h_input == 'C4' and h_output == 'Z2':
+ gconv_indices = flatten_indices(make_c4_z2_indices(ksize=ksize))
+ nti = 4
+ nto = 1
else:
raise ValueError('Unknown (h_input, h_output) pair:' + str((h_input, h_output)))
- w_shape = (ksize, ksize, in_channels * nti, out_channels)
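+    # For a Z2 output, the filter is allocated without the nti factor on its input channels.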
+ if h_output == 'Z2':
+ w_shape = (ksize, ksize, in_channels, out_channels)
+ else:
+ w_shape = (ksize, ksize, in_channels * nti, out_channels)
+
gconv_shape_info = (out_channels, nto, in_channels, nti, ksize)
return gconv_indices, gconv_shape_info, w_shape
@@ -94,7 +104,6 @@ def gconv2d_addbias(input, bias, nti=8):
A G-feature map usually consists of a number (e.g. 4 or 8) adjacent channels.
This function will add a single bias vector to a stack of feature maps that has e.g. 4 or 8 times more 2D channels
than G-channels, by replicating the bias across adjacent groups of 2D channels.
-
:param input: tensor of shape (n, h, w, ni * nti), where n is the batch dimension, (h, w) are the height and width,
ni is the number of input G-channels, and nti is the number of transformations in H.
:param bias: tensor of shape (ni,)
@@ -103,3 +112,4 @@ def gconv2d_addbias(input, bias, nti=8):
"""
# input = tf.reshape(input, ())
pass # TODO
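+    # A possible sketch (untested), assuming the channel axis is laid out as
+    # ni blocks of nti adjacent transformation channels:
+    #   b = tf.reshape(tf.tile(tf.reshape(bias, (-1, 1)), (1, nti)), (-1,))
+    #   return input + b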
+