Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
3575f54
lightgbm: Add SoftLabelParamParserUtil
Dec 15, 2022
3a320d2
chore(lightgbm): move SchemaFieldsUtil to lightgbm package
May 13, 2026
4b743e3
feat(lightgbm): added utility class to parse lightgbm sample weight t…
May 13, 2026
b6b0ce3
feat(lightgbm): add sample weight parameter to lightgbm descriptor
May 13, 2026
8ab146c
feat(lightgbm): modify training to make use of sample weight, when pr…
May 13, 2026
74ceac0
chore(lightgbm): updated tests to validate sample weight
May 13, 2026
629ced4
chore(cicd): bump cache maven packages
May 14, 2026
496f967
fix(submodules): pointing make-lightgbm submodule to fork with AMD64 …
May 14, 2026
690313c
fix(submodules): update make-lightgbm submodule
May 14, 2026
bcb3b8e
fix(submodules): update make-lightgbm submodule
May 14, 2026
241ec89
fix(submodules): update make-lightgbm submodule
May 14, 2026
905a26f
fix(pom): pin maven-antrun-plugin version
May 14, 2026
d1ce4d4
chore(cicd): bumping codecod to v5
May 15, 2026
47faf1f
fix(lightgbm): add check for sample weight values
May 15, 2026
35bdb1b
chore(tests): added tests for SchemaFieldsUtil class
May 15, 2026
817ecec
chore(tests): add test for negative sample weights
May 15, 2026
b23a131
fix(submodules): update make-lightgbm submodule
May 15, 2026
d3e0230
chore(tests): added tests to cover untested lines
May 18, 2026
f7fa44e
fix(submodules): pointing make-lightgbm to latest upstream
May 18, 2026
95a6836
chore(fairgbm-descriptor): fix link to feedzai's fairgbm documentation
May 18, 2026
d0b10d0
chore(lint): fix variable name and added missing parameter to method …
May 18, 2026
dde15e0
fix(lightgbm): simplified logic to determine the number of features, …
May 18, 2026
f6d1c1a
chore(lightgbm): added method to retrieve the relevant schema for loa…
May 19, 2026
9013fd1
chore(lightgbm): updated javadoc and parameter names in LightGBMModel…
May 19, 2026
5b7b9be
chore(lightgbm): unified logic to sanitize field names (replace space…
May 19, 2026
0ba4b77
fix(lightgbm): fixed validations for model's predictive fields - ensu…
May 19, 2026
c515b98
chore(lightgbm): fix constructor argument
May 19, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
distribution: 'zulu'

- name: Cache Maven packages
uses: actions/cache@v3.3.2
uses: actions/cache@v4
with:
path: ~/.m2
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
Expand Down Expand Up @@ -70,6 +70,8 @@ jobs:
cd /feedzai-openml-java && \
mvn test -B -Dsurefire.failIfNoSpecifiedTests=false -Dtest=!ClassifyUnknownCategoryTest#test,!H2OModelProviderTrainTest#trainModelsForAllAlgorithms'

- uses: codecov/codecov-action@v1
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
use_oidc: true
fail_ci_if_error: true
5 changes: 3 additions & 2 deletions openml-lightgbm/lightgbm-builder/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,16 @@

<plugin>
<artifactId>maven-antrun-plugin</artifactId>
<version>3.1.0</version>
<executions>
<execution>
<!-- this should happen after exec-maven-plugin -->
<phase>generate-resources</phase>
<configuration>
<tasks>
<target>
<echo message="unzipping file" />
<unzip src="${basedir}/make-lightgbm/build/lightgbmlib.jar" dest="${project.build.directory}/classes" />
</tasks>
</target>
</configuration>
<goals>
<goal>run</goal>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public enum LightGBMAlgorithms implements MLAlgorithmEnum {
"FairGBM (LightGBM with Fairness)",
FairGBMDescriptorUtil.PARAMS,
MachineLearningAlgorithmType.SUPERVISED_BINARY_CLASSIFICATION,
"https://lightgbm.readthedocs.io/" // TODO: link to our documentation
"https://cam.feedzai.com/pulse/latest/docs/common/pulse/understanding-pulse/core-concepts/models/model-training/feedzai-fairgbm"
)),
;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.feedzai.openml.data.schema.CategoricalValueSchema;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.data.schema.FieldSchema;

import com.google.common.collect.ImmutableSet;
import com.microsoft.ml.lightgbm.SWIGTYPE_p_float;
import com.microsoft.ml.lightgbm.SWIGTYPE_p_int;
Expand All @@ -43,6 +44,7 @@

import static com.feedzai.openml.provider.lightgbm.FairGBMDescriptorUtil.CONSTRAINT_GROUP_COLUMN_PARAMETER_NAME;
import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.NUM_ITERATIONS_PARAMETER_NAME;
import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.SAMPLE_WEIGHT_COL_PARAMETER_NAME;
import static java.lang.Integer.parseInt;
import static java.util.stream.Collectors.toList;

Expand Down Expand Up @@ -129,7 +131,16 @@ static void fit(final Dataset dataset,
final long instancesPerChunk) {

final DatasetSchema schema = dataset.getSchema();
final int numFeatures = schema.getPredictiveFields().size();

final Optional<String> sampleWeightFieldName = SampleWeightParamParserUtil.getSampleWeightFieldName(params);
int numFeatures = schema.getPredictiveFields().size();

// Check if the weight field exists and is explicitly part of the predictive fields
final Optional<Integer> sampleWeightColIndex =
SampleWeightParamParserUtil.getSampleWeightColumnIndex(params, schema);
if (sampleWeightColIndex.isPresent()) {
numFeatures--;
}

// Parse train parameters to LightGBM format
final String trainParams = getLightGBMTrainParamsString(params, schema);
Expand All @@ -140,13 +151,22 @@ static void fit(final Dataset dataset,
final SWIGTrainData swigTrainData = new SWIGTrainData(
numFeatures,
instancesPerChunk,
FairGBMParamParserUtil.isFairnessConstrained(params));
FairGBMParamParserUtil.isFairnessConstrained(params),
sampleWeightColIndex.isPresent()
);
final SWIGTrainBooster swigTrainBooster = new SWIGTrainBooster();

/// Create LightGBM dataset
final int constraintGroupColIndex = FairGBMParamParserUtil.getConstraintGroupColumnIndex(params, schema).orElse(
FairGBMParamParserUtil.NO_SPECIFIC);
createTrainDataset(dataset, numFeatures, trainParams, constraintGroupColIndex, swigTrainData);
createTrainDataset(
dataset,
numFeatures,
trainParams,
constraintGroupColIndex,
sampleWeightColIndex,
swigTrainData
);

/// Create Booster from dataset
createBoosterStructure(swigTrainBooster, swigTrainData, trainParams);
Expand Down Expand Up @@ -203,6 +223,7 @@ private static void createTrainDataset(final Dataset dataset,
final int numFeatures,
final String trainParams,
final int constraintGroupColIndex,
final Optional<Integer> sampleWeightColIndex,
final SWIGTrainData swigTrainData) {

logger.info("Creating LightGBM dataset");
Expand All @@ -211,7 +232,8 @@ private static void createTrainDataset(final Dataset dataset,
copyTrainDataToSWIGArrays(
dataset,
swigTrainData,
constraintGroupColIndex
constraintGroupColIndex,
sampleWeightColIndex
);

initializeLightGBMTrainDatasetFeatures(
Expand All @@ -224,13 +246,17 @@ private static void createTrainDataset(final Dataset dataset,
swigTrainData
);

if (sampleWeightColIndex.isPresent()) {
setLightGBMDatasetSampleWeightData(swigTrainData);
}

if (constraintGroupColIndex != FairGBMParamParserUtil.NO_SPECIFIC) {
setLightGBMDatasetConstraintGroupData(
swigTrainData
);
}

setLightGBMDatasetFeatureNames(swigTrainData.swigDatasetHandle, dataset.getSchema());
setLightGBMDatasetFeatureNames(swigTrainData.swigDatasetHandle, dataset.getSchema(), sampleWeightColIndex);

logger.info("Created LightGBM dataset.");
}
Expand Down Expand Up @@ -308,6 +334,33 @@ private static SWIGTYPE_p_int genSWIGFeatureChunkSizesArray(final SWIGTrainData
return swigChunkSizesArray;
}

/**
* Sets the LightGBM dataset sample weight data.
*
* @param swigTrainData SWIGTrainData object.
*/
private static void setLightGBMDatasetSampleWeightData(final SWIGTrainData swigTrainData) {
final long numInstances = swigTrainData.swigSampleWeightsChunkedArray.get_add_count();
// Init SWIG array and copy from chunked data.
SWIGTYPE_p_float swigSampleWeightsData = swigTrainData.coalesceChunkedSwigSampleWeightDataArray();
logger.debug("FTL: #weights={}", numInstances);

logger.debug("Setting sample weight data...");
final int returnCodeLGBM = lightgbmlib.LGBM_DatasetSetField(
swigTrainData.swigDatasetHandle,
"weight", // LightGBM weight column type.
lightgbmlib.float_to_voidp_ptr(swigSampleWeightsData),
(int) numInstances,
lightgbmlibConstants.C_API_DTYPE_FLOAT32
);
if (returnCodeLGBM == -1) {
logger.error("Could not set sample weight data.");
throw new LightGBMException();
}

swigTrainData.destroySwigSampleWeightsDataArray();
}

/**
* Sets the LightGBM dataset label data.
*
Expand Down Expand Up @@ -370,11 +423,16 @@ private static void setLightGBMDatasetConstraintGroupData(final SWIGTrainData sw
* @param swigDatasetHandle SWIG dataset handle
* @param schema Dataset schema
*/
private static void setLightGBMDatasetFeatureNames(final SWIGTYPE_p_void swigDatasetHandle, final DatasetSchema schema) {
private static void setLightGBMDatasetFeatureNames(final SWIGTYPE_p_void swigDatasetHandle,
final DatasetSchema schema,
final Optional<Integer> sampleWeightColIndex) {

final int numFeatures = schema.getPredictiveFields().size();
final List<FieldSchema> featureFields = schema.getPredictiveFields().stream()
.filter(field -> !sampleWeightColIndex.equals(Optional.of(field.getFieldIndex())))
.collect(toList());
final int numFeatures = featureFields.size();

final String[] featureNames = getFieldNames(schema.getPredictiveFields());
final String[] featureNames = getFieldNames(featureFields);
logger.debug("featureNames {}", Arrays.toString(featureNames));

final int returnCodeLGBM = lightgbmlib.LGBM_DatasetSetFeatureNames(swigDatasetHandle, featureNames, numFeatures);
Expand Down Expand Up @@ -482,12 +540,13 @@ static void saveModelFileToDisk(final SWIGTYPE_p_void swigBoosterHandle, final P
*/
private static void copyTrainDataToSWIGArrays(final Dataset dataset,
final SWIGTrainData swigTrainData) {
copyTrainDataToSWIGArrays(dataset, swigTrainData, FairGBMParamParserUtil.NO_SPECIFIC);
copyTrainDataToSWIGArrays(dataset, swigTrainData, FairGBMParamParserUtil.NO_SPECIFIC, Optional.empty());
}

private static void copyTrainDataToSWIGArrays(final Dataset dataset,
final SWIGTrainData swigTrainData,
final int constraintGroupIndex) {
final int constraintGroupIndex,
final Optional<Integer> sampleWeightColIndex) {

final DatasetSchema datasetSchema = dataset.getSchema();
final int numFields = datasetSchema.getFieldSchemas().size();
Expand All @@ -496,6 +555,7 @@ private static void copyTrainDataToSWIGArrays(final Dataset dataset,
ValidationUtils' validateCategoricalSchema:
*/
final int targetIndex = datasetSchema.getTargetIndex().get();
final int sampleWeightIdx = sampleWeightColIndex.orElse(-1);

final Iterator<Instance> iterator = dataset.getInstances();
while (iterator.hasNext()) {
Expand All @@ -507,8 +567,21 @@ private static void copyTrainDataToSWIGArrays(final Dataset dataset,
swigTrainData.addConstraintGroupValue((int) instance.getValue(constraintGroupIndex));
}

// Add sample weight and validate that it is non-negative
sampleWeightColIndex.ifPresent(integer -> {
final float weight = (float) instance.getValue(integer);
if (weight < 0) {
throw new IllegalArgumentException(String.format(
"Sample weight must be non-negative, but got: %f", weight
));
}
swigTrainData.addSampleWeightValue(weight);
});

for (int colIdx = 0; colIdx < numFields; ++colIdx) {
if (colIdx != targetIndex) {
// Don't add features for target and sample weight columns (in the case of sample weight,
// only if this is defined)
if ((colIdx != targetIndex) && (colIdx != sampleWeightIdx)) {
swigTrainData.addFeatureValue(instance.getValue(colIdx));
}
}
Expand All @@ -522,6 +595,11 @@ private static void copyTrainDataToSWIGArrays(final Dataset dataset,
swigTrainData.swigLabelsChunkedArray.get_add_count();
}

if (swigTrainData.useSampleWeight) {
assert swigTrainData.swigSampleWeightsChunkedArray.get_add_count() ==
swigTrainData.swigLabelsChunkedArray.get_add_count();
}

logger.debug("Copied train data of size {} into {} chunks.",
swigTrainData.swigLabelsChunkedArray.get_add_count(),
swigTrainData.swigLabelsChunkedArray.get_chunks_count()
Expand Down Expand Up @@ -559,8 +637,11 @@ private static String getLightGBMTrainParamsString(final Map<String, String> map
constraintGroupColIdx.ifPresent(
integer -> preprocessedMapParams.put(CONSTRAINT_GROUP_COLUMN_PARAMETER_NAME, Integer.toString(integer)));

// Add all **other** parameters
mapParams.forEach(preprocessedMapParams::putIfAbsent);
// Add all **other** parameters - remove sample weight parameter as it is passed to LightGBM in a different way
// (via lightgbmlib.LGBM_DatasetSetField(..., "weight", ...))
mapParams.entrySet().stream()
.filter(e -> !e.getKey().equals(SAMPLE_WEIGHT_COL_PARAMETER_NAME))
.forEach(e -> preprocessedMapParams.putIfAbsent(e.getKey(), e.getValue()));

// Build string containing params in LightGBM format
final StringBuilder paramsBuilder = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import com.feedzai.openml.provider.descriptor.ModelParameter;
import com.feedzai.openml.provider.descriptor.fieldtype.BooleanFieldType;
import com.feedzai.openml.provider.descriptor.fieldtype.ChoiceFieldType;
import com.feedzai.openml.provider.descriptor.fieldtype.FreeTextFieldType;

import com.google.common.collect.ImmutableSet;

import java.util.Set;
Expand Down Expand Up @@ -52,6 +54,11 @@ public class LightGBMDescriptorUtil extends AlgoDescriptorUtil {
*/
public static final String BAGGING_FREQUENCY_PARAMETER_NAME = "bagging_freq";

/**
* Sample weight parameter name.
*/
public static final String SAMPLE_WEIGHT_COL_PARAMETER_NAME = "sample_weight";

/**
* Global seed parameter name.
*/
Expand Down Expand Up @@ -347,8 +354,17 @@ public class LightGBMDescriptorUtil extends AlgoDescriptorUtil {
"Set to true if training data is unbalanced. \nWhilst enabling this should increase the overall performance metric of the model, it will also result in poor estimates of the individual class probabilities. Cannot be used at the same time as 'scale_pos_weight'.", // TODO nam parameter in ui (scale_pos_weight)
NOT_MANDATORY,
new BooleanFieldType(false)
)
),
// TODO: https://lightgbm.readthedocs.io/en/latest/Parameters.html#scale_pos_weight ?? would require setting the pos label

new ModelParameter(
SAMPLE_WEIGHT_COL_PARAMETER_NAME,
"Sample Weight",
"Name of the field containing per-instance weights for training. \n"
+ "Higher weights result in the model prioritizing training on those samples. \n"
+ "Values must be non-negative. \n"
+ "If this field is selected, it is automatically dropped from the selected features.",
NOT_MANDATORY,
new FreeTextFieldType("")
)
);
}
Loading
Loading