Hi, I was sent here from StackOverflow.
I’m trying to implement the neural network described in this paper: https://journals.sdu.edu.kz/index.php/ais/article/view/410
I looked at the official examples and tried to adapt AdditionModelWithSeq2Seq to my requirements, but I’m not sure how to read in my data file correctly, and I don’t know whether a ComputationGraph or a MultiLayerNetwork is the right approach here.
My inputs are 32-bit numbers (n), each the product of two primes (p, q). Each input number should be fed to the network as a binary vector, and the network should output p.
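To make the input encoding concrete, here is a minimal sketch of turning n into a binary vector. It is only an illustration: the class name `BitEncoding`, the method `toBits`, and the most-significant-bit-first ordering are assumptions, not something taken from the paper. It uses `long` because an unsigned 32-bit value can exceed the range of a Java `int`.

```java
import java.util.Arrays;

// Hypothetical helper, not part of DL4J or the paper.
final class BitEncoding {

    // Returns the low `width` bits of n as 0/1 values, most significant bit first (assumed ordering).
    static int[] toBits(long n, int width) {
        if (width < 1 || width > 63) {
            throw new IllegalArgumentException("width must be between 1 and 63");
        }
        if (n < 0 || (n >>> width) != 0) {
            throw new IllegalArgumentException("n does not fit into " + width + " bits");
        }
        int[] bits = new int[width];
        for (int i = 0; i < width; i++) {
            // Shift the wanted bit down to position 0 and mask it out.
            bits[i] = (int) ((n >>> (width - 1 - i)) & 1L);
        }
        return bits;
    }

    public static void main(String[] args) {
        long n = 5L * 7L; // small example semiprime
        System.out.println(Arrays.toString(toBits(n, 8)));
    }
}
```

Each element of the returned array would then become one input value (or one time step, if the bits are fed to the LSTM as a sequence).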
The network described in the paper looks like this:
(LSTM -> Batch Normalization -> Dropout) x 3
(Dense -> Batch Normalization -> Dropout) x 2
Dense
My current approach for the computation graph looks like this:
ArrayList<String> inAndOutNames = new ArrayList<>();
String[] inputNames = new String[inputAmount];
InputType[] inputTypes = new InputType[inputAmount + 1];
for (int i = 0; i < inputAmount; i++) {
    inAndOutNames.add("bit" + i);
    inputNames[i] = "bit" + i;
    inputTypes[i] = InputType.recurrent(1);
}
inAndOutNames.add("p");
inputTypes[inputAmount] = InputType.recurrent(1);
ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
        .weightInit(WeightInit.XAVIER)
        .updater(new Adam(0.001))
        .seed(seed)
        .graphBuilder()
        .addInputs(inAndOutNames)
        .setInputTypes(inputTypes)
        .addLayer("l0", new DenseLayer.Builder().nIn(inputAmount).nOut(inputAmount).build(), inputNames)
        .addLayer("l1", new LSTM.Builder().nIn(inputAmount).nOut(128).activation(Activation.TANH).build(), "l0")
        .addLayer("l2", new LSTM.Builder().nIn(128).nOut(256).build(), "l1")
        .addLayer("l3", new DenseLayer.Builder().nIn(256).nOut(256).build(), "l2", "p")
        .addLayer("lOut", new DenseLayer.Builder().nIn(256).nOut(10).build(), "l3")
        .setOutputs("lOut")
        .build();
model = new ComputationGraph(configuration);
This is my attempt to read in my test data:
RecordReader recordReader;
recordReader = new CSVRecordReader(0, ",");
recordReader.initialize(new FileSplit(new File("src/main/resources/datasets/32-Bit x 10000 RSA Bit-Combinations (2023-03-20-12-45-16)")));
dl.setInputAmount(bits);
MultiDataSetIterator dataSetIterator = new NewCustomPrimeIterator(bits);
dataSetIterator = new RecordReaderMultiDataSetIterator.Builder(100)
        .addReader("reader", recordReader)
        .addInput("reader", 0, 31)
        .addOutput("reader", 32, 32)
        .build();
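For reference, the column ranges above (0 to 31 as input, 32 as output, which as far as I understand are inclusive) only line up if every CSV line holds the 32 bits of n followed by p in the last column. The following is a rough sketch of producing one such line. The class `CsvRowSketch` and the method `toCsvRow` are hypothetical names, and the layout itself is an assumption about the data file, not something confirmed by the post.

```java
// Hypothetical helper illustrating the assumed CSV layout: <bit 0>,...,<bit width-1>,<p>
final class CsvRowSketch {

    // Builds one CSV line: the bits of n = p * q (most significant first), then p as the label.
    static String toCsvRow(long p, long q, int width) {
        if (p <= 0 || q <= 0) {
            throw new IllegalArgumentException("p and q must be positive");
        }
        if (width < 1 || width > 62) {
            throw new IllegalArgumentException("width must be between 1 and 62");
        }
        long n = Math.multiplyExact(p, q); // throws ArithmeticException on long overflow
        if ((n >>> width) != 0) {
            throw new IllegalArgumentException("p * q does not fit into " + width + " bits");
        }
        StringBuilder row = new StringBuilder();
        for (int i = width - 1; i >= 0; i--) {
            row.append((n >>> i) & 1L).append(',');
        }
        row.append(p); // label column (column index `width`)
        return row.toString();
    }

    public static void main(String[] args) {
        System.out.println(toCsvRow(5L, 7L, 8));
    }
}
```

With `width = 32`, each line has 33 comma-separated values, matching the input columns 0 to 31 and the output column 32 configured on the iterator.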