One-hot transformations not working as expected

I am trying to build a RecordReaderMultiDataSetIterator from a variety of transformed feature types, some of which are one-hot-encoded classes. However, the transforms integerToOneHot and categoricalToOneHot are not behaving as I expect.

The example code below includes these six features, along with their corresponding transformations:

  • a - integer
  • b - integer > one-hot
  • c - string > integer
  • d - string > integer > one-hot
  • e - categorical > one-hot
  • f - string > categorical > one-hot
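Each transform changes how many columns a feature occupies in the transformed record, and that determines the column indices passed to the iterator in the code below. Here is a minimal pure-Kotlin sketch that computes those starting indices. It assumes that integerToOneHot(col, 0, 9) emits 10 columns, that categoricalToOneHot emits one column per class, and that both replace the original column in place. The helper columnOffsets is hypothetical and not part of DataVec.

```kotlin
// Hypothetical helper (not a DataVec API): given each feature's column width
// after the transform step, return the index of its first column in the record.
fun columnOffsets(widths: List<Pair<String, Int>>): Map<String, Int> {
    val offsets = LinkedHashMap<String, Int>()
    var next = 0
    for ((name, width) in widths) {
        offsets[name] = next // this feature starts where the previous one ended
        next += width
    }
    return offsets
}

// Widths under the first TransformProcess, assuming the one-hot sizes stated above:
// a and c stay single integers, b and d become 10 columns, e and f become 4 columns.
val oneHotLayout = columnOffsets(
    listOf("a" to 1, "b" to 10, "c" to 1, "d" to 10, "e" to 4, "f" to 4)
)
// oneHotLayout == {a=0, b=1, c=11, d=12, e=22, f=26}
```

These offsets match the indices 0, 1, 11, 12, 22 and 26 used in the first iterator. If every feature stays a single column, the offsets collapse to 0 through 5.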

The integer features a and c produce the expected output, but the integerToOneHot outputs for b and d have the bit for class 0 set in every example. The categoricalToOneHot outputs for e and f are also unexpected: they look as if the inputs had been (“b”, “a”, “a”, “a”) and (“a”, “b”, “a”, “a”), respectively.

Am I doing something wrong here? Example Kotlin code:

fun oneHotTest(): MultiDataSet {

    val classes = listOf("a", "b", "c", "d")
    val colAValues = listOf(1, 2, 3, 4)
    val colBValues = listOf(2, 3, 4, 5)
    val colCValues = listOf("3", "4", "5", "6")
    val colDValues = listOf("4", "5", "6", "7")
    val colEValues = listOf("a", "d", "b", "b")
    val colFValues = listOf("d", "a", "c", "c")

    val data = (0..3).map { listOf(
            IntWritable(colAValues[it]),
            IntWritable(colBValues[it]),
            Text(colCValues[it]),
            Text(colDValues[it]),
            Text(colEValues[it]),
            Text(colFValues[it])
    )}
    val reader = CollectionRecordReader(data)

    val schema = Schema.Builder()
            .addColumnInteger("a")
            .addColumnInteger("b")
            .addColumnString("c")
            .addColumnString("d")
            .addColumnCategorical("e", classes)
            .addColumnString("f")
            .build()

    val transformProcess = TransformProcess.Builder(schema)
            .integerToOneHot("b", 0, 9)
            .convertToInteger("c")
            .convertToInteger("d")
            .integerToOneHot("d", 0, 9)
            .categoricalToOneHot("e")
            .stringToCategorical("f", classes)
            .categoricalToOneHot("f")
            .build()

    val transformReader = TransformProcessRecordReader(reader, transformProcess)

    val readerName = "xReader"
    val dataIterator = RecordReaderMultiDataSetIterator.Builder(4)
            .addReader(readerName, transformReader)
            .addInput(readerName, 0, 0)
            .addInputOneHot(readerName, 1, 10)
            .addInput(readerName, 11, 11)
            .addInputOneHot(readerName, 12, 10)
            .addInputOneHot(readerName, 22, 4)
            .addInputOneHot(readerName, 26, 4)
            .build()

    return dataIterator.next()
}
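The symptoms described above can be reproduced without DL4J. The following pure-Kotlin sketch assumes, as the resolution below confirms, that addInputOneHot(reader, column, n) reads a single column and treats its value as a class index. The helpers oneHot and misreadClass are hypothetical. Pointing that read at the first column of an already one-hot-encoded block yields exactly the reported values.

```kotlin
// Hypothetical helper: one-hot encode a class index into a vector of length n,
// mimicking what integerToOneHot / categoricalToOneHot did in the transform step.
fun oneHot(index: Int, n: Int): IntArray = IntArray(n) { if (it == index) 1 else 0 }

// Hypothetical helper: read one column of that block (the first by default)
// and treat its 0/1 value as a class index, as addInputOneHot does with its column.
fun misreadClass(index: Int, n: Int, columnInBlock: Int = 0): Int =
    oneHot(index, n)[columnInBlock]

val classes = listOf("a", "b", "c", "d")

// b: values 2..5 one-hot encoded over 0..9; its first column is the "value == 0" flag.
val bSeen = listOf(2, 3, 4, 5).map { misreadClass(it, 10) } // → [0, 0, 0, 0]

// e and f: the first column of each block is the "class == a" flag.
val eSeen = listOf("a", "d", "b", "b")
    .map { classes[misreadClass(classes.indexOf(it), classes.size)] } // → [b, a, a, a]
val fSeen = listOf("d", "a", "c", "c")
    .map { classes[misreadClass(classes.indexOf(it), classes.size)] } // → [a, b, a, a]
```

The "value == 0" flag for b is 0 in every example, so the iterator encodes class 0 each time. The same one-column read explains the pattern seen for e and f.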

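For anyone else trying to use this feature: my mistake was that addInputOneHot doesn't expect a one-hot input. It expects a single column holding an integer class index, which it converts to one-hot format itself. The one-hot conversion in the transform step was therefore redundant. Worse, each addInputOneHot call read only the first column of an already one-hot-encoded block, which is why b and d always came out as class 0. With the one-hot transforms removed (and categoricalToInteger used for e and f), every feature occupies exactly one column, so the column indices become 0 to 5. The working version of the code above is: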

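import org.datavec.api.records.reader.impl.collection.CollectionRecordReader
import org.datavec.api.records.reader.impl.transform.TransformProcessRecordReader
import org.datavec.api.transform.TransformProcess
import org.datavec.api.transform.schema.Schema
import org.datavec.api.writable.IntWritable
import org.datavec.api.writable.Text
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator
import org.nd4j.linalg.dataset.api.MultiDataSet

fun oneHotTest(): MultiDataSet {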

    val classes = listOf("a", "b", "c", "d")
    val colAValues = listOf(1, 2, 3, 4)
    val colBValues = listOf(2, 3, 4, 5)
    val colCValues = listOf("3", "4", "5", "6")
    val colDValues = listOf("4", "5", "6", "7")
    val colEValues = listOf("a", "d", "b", "b")
    val colFValues = listOf("d", "a", "c", "c")

    val data = (0..3).map { listOf(
            IntWritable(colAValues[it]),
            IntWritable(colBValues[it]),
            Text(colCValues[it]),
            Text(colDValues[it]),
            Text(colEValues[it]),
            Text(colFValues[it])
    )}
    val reader = CollectionRecordReader(data)

    val schema = Schema.Builder()
            .addColumnInteger("a")
            .addColumnInteger("b")
            .addColumnString("c")
            .addColumnString("d")
            .addColumnCategorical("e", classes)
            .addColumnString("f")
            .build()

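    // Keep every feature as a single integer column; the iterator does the one-hot expansion
    val transformProcess = TransformProcess.Builder(schema)
            .convertToInteger("c")
            .convertToInteger("d")
            .categoricalToInteger("e")              // class index in the "a".."d" list
            .stringToCategorical("f", classes)
            .categoricalToInteger("f")              // class index in the "a".."d" list
            .build()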

    val transformReader = TransformProcessRecordReader(reader, transformProcess)

    val readerName = "xReader"
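    // addInputOneHot(reader, column, numClasses) reads one integer class index from that column
    val dataIterator = RecordReaderMultiDataSetIterator.Builder(4)
            .addReader(readerName, transformReader)
            .addInput(readerName, 0, 0)             // a
            .addInputOneHot(readerName, 1, 10)      // b: values 0..9
            .addInput(readerName, 2, 2)             // c
            .addInputOneHot(readerName, 3, 10)      // d: values 0..9
            .addInputOneHot(readerName, 4, 4)       // e: 4 classes
            .addInputOneHot(readerName, 5, 4)       // f: 4 classes
            .build()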

    return dataIterator.next()
}
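The fix comes down to what each side of the pipeline produces. The transform step should hand over one integer per example, and the iterator's one-hot expansion turns that integer into a row with a single 1. This is a minimal pure-Kotlin sketch of that contract, not DL4J's actual implementation. It assumes categoricalToInteger maps each category to its position in the class list. The helper names toOneHotMatrix and categoricalToIndex are hypothetical.

```kotlin
// Hypothetical sketch of the addInputOneHot contract: one integer class index
// per example, expanded into one row per example with a single 1.
fun toOneHotMatrix(indices: List<Int>, numClasses: Int): List<IntArray> =
    indices.map { idx ->
        require(idx in 0 until numClasses) { "class index $idx is outside [0, $numClasses)" }
        IntArray(numClasses) { if (it == idx) 1 else 0 }
    }

// Mimics categoricalToInteger (assumption): a category becomes its position in the class list.
fun categoricalToIndex(values: List<String>, classes: List<String>): List<Int> =
    values.map { v ->
        val idx = classes.indexOf(v)
        require(idx >= 0) { "unknown category $v" }
        idx
    }

val classes = listOf("a", "b", "c", "d")
val eIndices = categoricalToIndex(listOf("a", "d", "b", "b"), classes) // → [0, 3, 1, 1]
val eOneHot = toOneHotMatrix(eIndices, classes.size)                   // one row per example
```

Under this contract each feature in the transformed record is a single column, and only the iterator knows about the one-hot width.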