Skip to content

Commit c801ad5

Browse files
update
1 parent 09d2e2e commit c801ad5

File tree

2 files changed

+25
-21
lines changed

2 files changed

+25
-21
lines changed

examples/log_reg_binary_dense_batch.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,29 @@
3131

3232

3333
def main():
34+
nClasses = 2
35+
nFeatures = 20
36+
3437
# read training data from file with 20 features per observation and 1 class label
3538
trainfile = "./data/batch/binary_cls_train.csv"
36-
train_data = read_csv(trainfile, range(20))
37-
train_dep_data = read_csv(trainfile, range(20, 21))
38-
nVectors = train_data.shape[0]
39-
train_dep_data.shape = (nVectors, 1) # must be a 2d array
39+
train_data = read_csv(trainfile, range(nFeatures))
40+
train_labels = read_csv(trainfile, range(nFeatures, nFeatures + 1))
41+
train_labels.shape = (train_data.shape[0], 1) # must be a 2d array
4042

4143
# set parameters and train
42-
train_alg = d4p.logistic_regression_training(nClasses=2)
43-
train_result = train_alg.compute(train_data, train_dep_data)
44+
train_alg = d4p.logistic_regression_training(nClasses=nClasses)
45+
train_result = train_alg.compute(train_data, train_labels)
4446

4547
# read testing data from file with 20 features per observation
4648
testfile = "./data/batch/binary_cls_test.csv"
47-
predict_data = read_csv(testfile, range(20))
49+
predict_data = read_csv(testfile, range(nFeatures))
4850

4951
# set parameters and compute predictions
50-
predict_alg = d4p.logistic_regression_prediction(nClasses=2)
52+
predict_alg = d4p.logistic_regression_prediction(nClasses=nClasses)
5153
predict_result = predict_alg.compute(predict_data, train_result.model)
5254

5355
# the prediction result provides prediction
54-
assert predict_result.prediction.shape == (predict_data.shape[0], train_dep_data.shape[1])
56+
assert predict_result.prediction.shape == (predict_data.shape[0], train_labels.shape[1])
5557

5658

5759
if __name__ == "__main__":

examples/log_reg_dense_batch.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,32 +31,34 @@
3131

3232

3333
def main():
34+
nClasses = 5
35+
nFeatures = 6
36+
3437
# read training data from file with 6 features per observation and 1 class label
3538
trainfile = "./data/batch/logreg_train.csv"
36-
train_data = read_csv(trainfile, range(6))
37-
train_dep_data = read_csv(trainfile, range(6, 7))
38-
nVectors = train_data.shape[0]
39-
train_dep_data.shape = (nVectors, 1) # must be a 2d array
39+
train_data = read_csv(trainfile, range(nFeatures))
40+
train_labels = read_csv(trainfile, range(nFeatures, nFeatures + 1))
41+
train_labels.shape = (train_data.shape[0], 1) # must be a 2d array
4042

4143
# set parameters and train
42-
train_alg = d4p.logistic_regression_training(nClasses=5,
44+
train_alg = d4p.logistic_regression_training(nClasses=nClasses,
4345
penaltyL1=0.1,
4446
penaltyL2=0.1)
45-
train_result = train_alg.compute(train_data, train_dep_data)
47+
train_result = train_alg.compute(train_data, train_labels)
4648

4749
# read testing data from file with 6 features per observation
4850
testfile = "./data/batch/logreg_test.csv"
49-
predict_data = read_csv(testfile, range(6))
51+
predict_data = read_csv(testfile, range(nFeatures))
5052

5153
# set parameters and compute predictions
52-
predict_alg = d4p.logistic_regression_prediction(nClasses=5,
54+
predict_alg = d4p.logistic_regression_prediction(nClasses=nClasses,
5355
resultsToCompute="computeClassesLabels|computeClassesProbabilities|computeClassesLogProbabilities")
5456
predict_result = predict_alg.compute(predict_data, train_result.model)
5557

56-
# the prediction result provides prediction, classes probabilities and classes log probabilities
57-
assert predict_result.prediction.shape == (predict_data.shape[0], train_dep_data.shape[1]) \
58-
and predict_result.probabilities.shape == (predict_data.shape[0], 5) \
59-
and predict_result.logProbabilities.shape == (predict_data.shape[0], 5)
58+
# the prediction result provides prediction, probabilities and logProbabilities
59+
assert predict_result.prediction.shape == (predict_data.shape[0], train_labels.shape[1])
60+
assert predict_result.probabilities.shape == (predict_data.shape[0], nClasses)
61+
assert predict_result.logProbabilities.shape == (predict_data.shape[0], nClasses)
6062

6163

6264
if __name__ == "__main__":

0 commit comments

Comments
 (0)