Temporal Convolutional Network

Is there an example in DL4J of wiring up a network similar to GitHub - philipperemy/keras-tcn: Keras Temporal Convolutional Network? I am trying to use this style of network on a sequence of GloVe word embeddings, with a setup like the one below:

ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(12345)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(new Nesterovs())
        .weightInit(WeightInit.XAVIER)
        .activation(Activation.TANH)
        .graphBuilder()
        .addInputs("in")
        // stack of causal 1D convolutions with exponentially increasing dilation
        .addLayer("c0", new Convolution1D.Builder(3, 1).convolutionMode(ConvolutionMode.Causal).dilation(1).nIn(300).nOut(256).build(), "in")
        .addLayer("c1", new Convolution1D.Builder(3, 1).convolutionMode(ConvolutionMode.Causal).dilation(2).nIn(256).nOut(256).build(), "c0")
        .addLayer("c2", new Convolution1D.Builder(3, 1).convolutionMode(ConvolutionMode.Causal).dilation(4).nIn(256).nOut(256).build(), "c1")
        .addLayer("c3", new Convolution1D.Builder(3, 1).convolutionMode(ConvolutionMode.Causal).dilation(8).nIn(256).nOut(256).build(), "c2")
        // collapse the time dimension, then classify
        .addLayer("p2", new GlobalPoolingLayer.Builder(PoolingType.MAX).build(), "c3")
        .addLayer("batchNorm", new BatchNormalization.Builder().nIn(256).nOut(256).build(), "p2")
        .addLayer("l0", new DenseLayer.Builder().nIn(256).nOut(256).dropOut(0.5).build(), "batchNorm")
        .addLayer("l1", new DenseLayer.Builder().nIn(256).nOut(128).build(), "l0")
        .addLayer("out", new OutputLayer.Builder().nIn(128).nOut(1).lossFunction(LossFunctions.LossFunction.XENT).activation(Activation.SIGMOID).build(), "l1")
        .setOutputs("out")
        .build();

But I keep getting errors like the following:

 Failed to execute op conv1d. Attempted to execute with 3 inputs, 1 outputs, 0 targs,0 bargs and 6 iargs. Inputs: [(FLOAT,[200,256,300],c), (FLOAT,[3,256,256],f), (FLOAT,[256],c)]. Outputs: [(FLOAT,[200,256,300],f)]. tArgs: -. iArgs: [3, 1, 0, 4, 2, 0]. bArgs: -. Op own name: "f428232c-01be-4f85-a02a-91a72a8adb6f" - Please see above message (printed out from c++) for a possible cause of error.


java.lang.RuntimeException: Op [conv1d] execution failed
	at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:1723) ~[nd4j-native-1.0.0-beta6.jar:?]
	at org.nd4j.linalg.factory.Nd4j.exec(Nd4j.java:6599) ~[nd4j-api-1.0.0-beta6.jar:1.0.0-beta6]
	at org.deeplearning4j.nn.layers.convolution.Convolution1DLayer.causalConv1dForward(Convolution1DLayer.java:206) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at org.deeplearning4j.nn.layers.convolution.Convolution1DLayer.preOutput(Convolution1DLayer.java:161) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at org.deeplearning4j.nn.layers.convolution.ConvolutionLayer.activate(ConvolutionLayer.java:446) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at org.deeplearning4j.nn.layers.convolution.Convolution1DLayer.activate(Convolution1DLayer.java:212) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at org.deeplearning4j.nn.graph.vertex.impl.LayerVertex.doForward(LayerVertex.java:111) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at org.deeplearning4j.nn.graph.ComputationGraph.ffToLayerActivationsInWS(ComputationGraph.java:2136) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1373) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1342) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:170) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:63) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1166) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1116) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1083) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1019) ~[deeplearning4j-nn-1.0.0-beta6.jar:?]
	at com.masked.aiproject.GloveDilatedCNN.trainEpoch(GloveDilatedCNN.java:264) ~[main/:?]
	at com.masked.aiproject.GloveDilatedCNN.handleFile(GloveDilatedCNN.java:205) ~[main/:?]
	at com.masked.aiproject.GloveDilatedCNN.lambda$main$0(GloveDilatedCNN.java:117) [main/:?]
	at com.google.common.util.concurrent.TrustedListenableFutureTask$TrustedFutureInterruptibleTask.runInterruptibly(TrustedListenableFutureTask.java:125) [guava-26.0-jre.jar:?]
	at com.google.common.util.concurrent.InterruptibleTask.run(InterruptibleTask.java:57) [guava-26.0-jre.jar:?]
	at com.google.common.util.concurrent.TrustedListenableFutureTask.run(TrustedListenableFutureTask.java:78) [guava-26.0-jre.jar:?]
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128) [?:?]
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628) [?:?]
	at java.lang.Thread.run(Thread.java:834) [?:?]
Caused by: java.lang.RuntimeException: could not create a dilated convolution forward descriptor
	at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:2019) ~[nd4j-native-1.0.0-beta6.jar:?]
	at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:1713) ~[nd4j-native-1.0.0-beta6.jar:?]
	... 25 more

Does anyone have an example or any tips on how to debug this?

Edit: Added formatting so it is easier to read the post.

I think the example you are looking for is CNN Sentence Classification.

It iterates over a sequence of W2V embeddings.
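
Under the hood it loads the vectors into a WordVectors instance; a minimal sketch of that load, assuming a GloVe text file (the path is a placeholder):

    // org.deeplearning4j.models.embeddings.loader.WordVectorSerializer
    WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File("/path/to/glove.6B.300d.txt"));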

Also, was there anything else printed above this exception? The C++ error messages go straight to STDERR if I remember correctly and usually have more information about what exactly may be wrong in your setup.

Thanks for the example; I will take a look.

As a side note, if I remove the dilation(x) calls, this setup works fine.

I am rerunning again to look closer for the C++ error message.
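
(Side note on what the dilation schedule buys: each causal conv layer extends the look-back by (kernelSize - 1) * dilation timesteps. A quick back-of-the-envelope sketch in plain Java — my own arithmetic, not a DL4J API:)

    int kernelSize = 3;
    int[] dilations = {1, 2, 4, 8};
    int receptiveField = 1;                     // the current timestep itself
    for (int d : dilations) {
        receptiveField += (kernelSize - 1) * d; // extra history each layer adds
    }
    System.out.println(receptiveField);         // 1 + 2*(1+2+4+8) = 31 timesteps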

I am able to reproduce my error on the CnnSentenceClassificationExample. If I change the network setup to

 ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.RELU)
                .activation(Activation.LEAKYRELU)
                .updater(new Adam(0.01))
                .graphBuilder()
                .addInputs("input")
                .addLayer("c0", new Convolution1D.Builder(3, 1 ).convolutionMode(ConvolutionMode.Causal).nIn(300).nOut(128).build(), "input")
                .addLayer("c1", new Convolution1D.Builder(3, 1).convolutionMode(ConvolutionMode.Causal).dilation(2).nIn(128).nOut(128).build(), "c0")
                .addLayer("c2", new Convolution1D.Builder(3, 1).convolutionMode(ConvolutionMode.Causal).dilation(4).nIn(128).nOut(128).build(), "c1")
                .addLayer("c3", new Convolution1D.Builder(3, 1).convolutionMode(ConvolutionMode.Causal).dilation(8).nIn(128).nOut(128).build(), "c2")
                .addLayer("p2", new GlobalPoolingLayer.Builder(PoolingType.MAX).build(), "c3")
                .addLayer("batchNorm", new BatchNormalization.Builder().nIn(128).nOut(128).build(), "p2")
                .addLayer("l0", new DenseLayer.Builder().nIn(128).nOut(128).build(), "batchNorm")
                .addLayer("l1", new DenseLayer.Builder().nIn(128).nOut(128).build(), "l0")
                .addLayer("out", new OutputLayer.Builder()
                        .lossFunction(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX)
                        .nIn(128)
                        .nOut(2)    //2 classes: positive or negative
                        .build(), "l1")
                .setOutputs("out")
                .build();

And change the iterator to the CNN1D format:

DataSetIterator trainIter = new CnnSentenceDataSetIterator.Builder(CnnSentenceDataSetIterator.Format.CNN1D)
        .sentenceProvider(sentenceProvider)
        .wordVectors(wordVectors)
        .minibatchSize(minibatchSize)
        .maxSentenceLength(maxSentenceLength)
        .useNormalizedWordVectors(false)
        .build();

It explodes. It works fine when I remove the dilation.
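
(Concretely, the explosion happens as soon as the graph is fit on that iterator; a minimal sketch of the call site, reusing config and trainIter from above:)

    ComputationGraph net = new ComputationGraph(config);
    net.init();
    net.fit(trainIter);   // conv1d op fails on the first minibatch when dilation > 1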

I don’t see anything useful on standard error other than what was posted above.

2020-02-05 15:30:12,996 [main] <INFO> factory.Nd4jBackend: Loaded [CpuBackend] backend
2020-02-05 15:30:13,383 [main] <INFO> nativeblas.NativeOpsHolder: Number of threads used for linear algebra: 12
2020-02-05 15:30:13,419 [main] <INFO> nativeblas.Nd4jBlas: Number of threads used for OpenMP BLAS: 12
2020-02-05 15:30:13,422 [main] <INFO> executioner.DefaultOpExecutioner: Backend used: [CPU]; OS: [Linux]
2020-02-05 15:30:13,422 [main] <INFO> executioner.DefaultOpExecutioner: Cores: [24]; Memory: [30.0GB];
2020-02-05 15:30:13,422 [main] <INFO> executioner.DefaultOpExecutioner: Blas vendor: [OPENBLAS]
2020-02-05 15:30:13,450 [main] <INFO> graph.ComputationGraph: Starting ComputationGraph with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE]
Number of parameters by layer:
	c0	115328
	c1	49280
	c2	49280
	c3	49280
	p2	0
	batchNorm	512
	l0	16512
	l1	16512
	out	258
Loading word vectors and creating DataSetIterators
Starting training
2020-02-05 15:31:08,962 [main] <ERROR> ops.NativeOpExecutioner: Failed to execute op conv1d. Attempted to execute with 3 inputs, 1 outputs, 0 targs,0 bargs and 6 iargs. Inputs: [(FLOAT,[32,128,256],c), (FLOAT,[3,128,128],f), (FLOAT,[128],c)]. Outputs: [(FLOAT,[32,128,256],f)]. tArgs: -. iArgs: [3, 1, 0, 4, 2, 0]. bArgs: -. Op own name: "a9b9d1b1-bea5-46f8-9b71-4f40f5d07dd6" - Please see above message (printed out from c++) for a possible cause of error.
Exception in thread "main" java.lang.RuntimeException: Op [conv1d] execution failed
	at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:1723)
	at org.nd4j.linalg.factory.Nd4j.exec(Nd4j.java:6599)
	at org.deeplearning4j.nn.layers.convolution.Convolution1DLayer.causalConv1dForward(Convolution1DLayer.java:206)
	at org.deeplearning4j.nn.layers.convolution.Convolution1DLayer.preOutput(Convolution1DLayer.java:161)
	at org.deeplearning4j.nn.layers.convolution.ConvolutionLayer.activate(ConvolutionLayer.java:446)
	at org.deeplearning4j.nn.layers.convolution.Convolution1DLayer.activate(Convolution1DLayer.java:212)
	at org.deeplearning4j.nn.graph.vertex.impl.LayerVertex.doForward(LayerVertex.java:111)
	at org.deeplearning4j.nn.graph.ComputationGraph.ffToLayerActivationsInWS(ComputationGraph.java:2136)
	at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1373)
	at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1342)
	at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:170)
	at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:63)
	at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52)
	at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1166)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1116)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1083)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1019)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1007)
	at org.deeplearning4j.examples.convolution.sentenceclassification.CnnSentenceClassificationExample.main(CnnSentenceClassificationExample.java:137)
Caused by: java.lang.RuntimeException: could not create a dilated convolution forward descriptor
	at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:2019)
	at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:1713)
	... 18 more


> Task :analyses:CnnSentenceClassificationExample.main() FAILED

Execution failed for task ':analyses:CnnSentenceClassificationExample.main()'.
> Process 'command '/home/mdavis/.local/share/JetBrains/Toolbox/apps/IDEA-U/ch-0/193.5662.53/jbr/bin/java'' finished with non-zero exit value 1

Can you please disable MKL-DNN use and run again?

You can do it with this line of code:

Nd4j.getEnvironment().allowHelpers(false);

Just add it at the top of your code.

@mdavis95 @raver119 Unfortunately Nd4j.getEnvironment() was added after 1.0.0-beta6 was released.

You can use this instead, which works for 1.0.0-beta6 and earlier:
Nd4jCpu.Environment.getInstance().allowHelpers(false);
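
For example, at the very top of main, before the graph or iterator is built (a sketch; the class name is taken from your stack trace, the rest is placeholder):

    import org.nd4j.nativeblas.Nd4jCpu;

    public class GloveDilatedCNN {
        public static void main(String[] args) {
            // Disable MKL-DNN helpers; nd4j falls back to its generic C++
            // implementations, which support dilated causal conv1d.
            Nd4jCpu.Environment.getInstance().allowHelpers(false);
            // ... build and fit the ComputationGraph as before ...
        }
    }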

Ouch! :slight_smile:

Well, same idea. Let’s disable MKL-DNN and see what happens.

Yep, that makes it work.

I see, thank you. We’ll investigate why MKL-DNN forbids that. It might be a natural limitation on their side.

Would this likely be fine on a GPU?

According to the documentation, cuDNN supports dilation, so I expect this to be fine on a GPU.

It works great on GPU, thanks! I appreciate the help getting this up and running.

I tried to track down MKL-DNN support for these features and only found this check:

Debugging that is outside my league, but I figured I would pass on the link on the off chance it helps.