I am attempting to produce a `RecordReaderMultiDataSetIterator` that incorporates a variety of transformed feature types, some of which are one-hot-encoded classes. However, the functions `integerToOneHot` and `categoricalToOneHot` are not behaving as I would expect.
The example code below includes these six features, along with their corresponding transformations:
- `a`: integer
- `b`: integer > one-hot
- `c`: string > integer
- `d`: string > integer > one-hot
- `e`: categorical > one-hot
- `f`: string > categorical > one-hot
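For reference, the encodings I would expect can be written out in plain Kotlin, without DataVec. This is a minimal sketch that assumes `integerToOneHot(col, 0, 9)` yields a 10-wide vector with the 1 at position `value - 0`, and that `categoricalToOneHot` orders positions as in `classes`. `oneHot` and the `expected*` values are hypothetical names, not DataVec APIs:

```kotlin
// Hypothetical helper (not a DataVec API): a 0/1 vector of length `size`
// with a single 1 at position `index`.
fun oneHot(index: Int, size: Int): List<Int> {
    require(index in 0 until size) { "index $index is outside 0 until $size" }
    return List(size) { if (it == index) 1 else 0 }
}

val classes = listOf("a", "b", "c", "d")

// b and d: integers in [0, 9] -> 10-wide vectors (min is 0, so position = value).
val expectedB = listOf(2, 3, 4, 5).map { oneHot(it, 10) }
val expectedD = listOf("4", "5", "6", "7").map { oneHot(it.toInt(), 10) }

// e and f: categories -> 4-wide vectors, positions ordered as in `classes`.
val expectedE = listOf("a", "d", "b", "b").map { oneHot(classes.indexOf(it), classes.size) }
val expectedF = listOf("d", "a", "c", "c").map { oneHot(classes.indexOf(it), classes.size) }
```

None of the expected `b` or `d` rows has its 0 bit set, because every value lies between 2 and 7.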
While the integer features `a` and `c` produce the expected output, the `integerToOneHot` outputs for `b` and `d` all have the 0 bit flagged. For the `categoricalToOneHot` outputs of `e` and `f`, I also get unexpected results, as if the inputs had been ("b", "a", "a", "a") and ("a", "b", "a", "a"), respectively.
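These exact values can be reproduced under one assumption: that `addInputOneHot(readerName, column, numClasses)` reads only the single value at `column` as a class index and expands it to `numClasses` positions itself. In that case, for the already-one-hot blocks it would only ever see the first bit of each block. The plain-Kotlin simulation below makes no DataVec or DL4J calls; `transformedRow`, `readAsIndex` and `decode` are hypothetical helpers that model that assumption on the four example rows:

```kotlin
// Simulation only. ASSUMPTION: addInputOneHot(reader, column, numClasses) treats the
// single value at `column` as a class index and one-hot encodes it to numClasses.
val classes = listOf("a", "b", "c", "d")

fun oneHot(index: Int, size: Int): List<Int> =
    List(size) { if (it == index) 1 else 0 }

// One transformed row: a, b (10 cols), c, d (10 cols), e (4 cols), f (4 cols) = 30 columns.
fun transformedRow(a: Int, b: Int, c: String, d: String, e: String, f: String): List<Int> =
    listOf(a) + oneHot(b, 10) + listOf(c.toInt()) + oneHot(d.toInt(), 10) +
        oneHot(classes.indexOf(e), 4) + oneHot(classes.indexOf(f), 4)

// Index-based expansion of a single column, as assumed above.
fun readAsIndex(row: List<Int>, column: Int, numClasses: Int): List<Int> =
    oneHot(row[column], numClasses)

// Map a 4-wide one-hot vector back to its class label.
fun decode(v: List<Int>): String = classes[v.indexOf(1)]

val rows = listOf(
    transformedRow(1, 2, "3", "4", "a", "d"),
    transformedRow(2, 3, "4", "5", "d", "a"),
    transformedRow(3, 4, "5", "6", "b", "c"),
    transformedRow(4, 5, "6", "7", "b", "c")
)

// Column 1 is the "b = 0" bit, column 22 the "e = a" bit, column 26 the "f = a" bit.
val simulatedB = rows.map { readAsIndex(it, 1, 10) }
val simulatedE = rows.map { decode(readAsIndex(it, 22, 4)) }
val simulatedF = rows.map { decode(readAsIndex(it, 26, 4)) }
```

Under this assumption `simulatedB` is the 0-bit vector for every row, `simulatedE` is ("b", "a", "a", "a") and `simulatedF` is ("a", "b", "a", "a"), which matches the outputs described above.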
Am I doing something wrong here? Example Kotlin code:
```kotlin
import org.datavec.api.records.reader.impl.collection.CollectionRecordReader
import org.datavec.api.records.reader.impl.transform.TransformProcessRecordReader
import org.datavec.api.transform.TransformProcess
import org.datavec.api.transform.schema.Schema
import org.datavec.api.writable.IntWritable
import org.datavec.api.writable.Text
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator
import org.nd4j.linalg.dataset.api.MultiDataSet

fun oneHotTest(): MultiDataSet {
    val classes = listOf("a", "b", "c", "d")
    val colAValues = listOf(1, 2, 3, 4)
    val colBValues = listOf(2, 3, 4, 5)
    val colCValues = listOf("3", "4", "5", "6")
    val colDValues = listOf("4", "5", "6", "7")
    val colEValues = listOf("a", "d", "b", "b")
    val colFValues = listOf("d", "a", "c", "c")

    // Four records of six raw columns: a, b as IntWritable; c, d, e, f as Text
    val data = (0..3).map {
        listOf(
            IntWritable(colAValues[it]),
            IntWritable(colBValues[it]),
            Text(colCValues[it]),
            Text(colDValues[it]),
            Text(colEValues[it]),
            Text(colFValues[it])
        )
    }
    val reader = CollectionRecordReader(data)

    val schema = Schema.Builder()
        .addColumnInteger("a")
        .addColumnInteger("b")
        .addColumnString("c")
        .addColumnString("d")
        .addColumnCategorical("e", classes)
        .addColumnString("f")
        .build()

    val transformProcess = TransformProcess.Builder(schema)
        .integerToOneHot("b", 0, 9)
        .convertToInteger("c")
        .convertToInteger("d")
        .integerToOneHot("d", 0, 9)
        .categoricalToOneHot("e")
        .stringToCategorical("f", classes)
        .categoricalToOneHot("f")
        .build()
    val transformReader = TransformProcessRecordReader(reader, transformProcess)

    // Intended transformed layout: a=0, b=1..10, c=11, d=12..21, e=22..25, f=26..29
    val readerName = "xReader"
    val dataIterator = RecordReaderMultiDataSetIterator.Builder(4)
        .addReader(readerName, transformReader)
        .addInput(readerName, 0, 0)
        .addInputOneHot(readerName, 1, 10)
        .addInput(readerName, 11, 11)
        .addInputOneHot(readerName, 12, 10)
        .addInputOneHot(readerName, 22, 4)
        .addInputOneHot(readerName, 26, 4)
        .build()
    return dataIterator.next()
}
```
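The column arguments in the builder calls rely on the post-transform layout. A small plain-Kotlin sketch can derive each feature's start column as a running sum of widths. The widths are my assumption about what the transforms produce (1 for a plain integer, 10 for `integerToOneHot(col, 0, 9)`, 4 for a 4-class categorical one-hot). They are not read from DataVec, and `widths` and `startColumns` are hypothetical names:

```kotlin
// Assumed width of each feature after the TransformProcess (not queried from DataVec).
val widths = linkedMapOf(
    "a" to 1,   // plain integer
    "b" to 10,  // integerToOneHot(b, 0, 9)
    "c" to 1,   // convertToInteger
    "d" to 10,  // convertToInteger, then integerToOneHot(d, 0, 9)
    "e" to 4,   // categoricalToOneHot over 4 classes
    "f" to 4    // stringToCategorical, then categoricalToOneHot
)

// Running sum of widths gives each feature's first column in the transformed record.
fun startColumns(widths: Map<String, Int>): Map<String, Int> {
    var next = 0
    val starts = linkedMapOf<String, Int>()
    for ((name, width) in widths) {
        starts[name] = next
        next += width
    }
    return starts
}

val starts = startColumns(widths)
val totalColumns = widths.values.sum()
```

The resulting starts (0, 1, 11, 12, 22, 26) agree with the column indices used in the `addInput` and `addInputOneHot` calls, for a total of 30 transformed columns.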