|
19 | 19 | import time
|
20 | 20 |
|
21 | 21 | # 数据路径
|
22 |
| -data_dir = '/media/wsw/B634091A3408DF6D/data/kaggle/datasets/getting-started/digit-recognizer/' |
| 22 | +data_dir = '/Users/wuyanxue/Documents/GitHub/datasets/getting-started/digit-recognizer/' |
23 | 23 |
|
24 | 24 | # 加载数据
|
25 | 25 | def opencsv():
|
@@ -57,8 +57,8 @@ def dRPCA(data, COMPONENT_NUM=100):
|
57 | 57 | def trainModel(X_train, y_train):
|
58 | 58 | print('Train RF...')
|
59 | 59 | clf = RandomForestClassifier(
|
60 |
| - n_estimators=140, |
61 |
| - max_depth=20, |
| 60 | + n_estimators=10, |
| 61 | + max_depth=10, |
62 | 62 | min_samples_split=2,
|
63 | 63 | min_samples_leaf=1,
|
64 | 64 | random_state=34)
|
@@ -99,16 +99,13 @@ def getModel(filename):
|
99 | 99 | # 结果输出保存
|
100 | 100 | def saveResult(result, csvName):
|
101 | 101 | i = 0
|
102 |
| - fw = open(csvName, 'w') |
103 |
| - with open(os.path.join(data_dir, 'output/sample_submission.csv') |
104 |
| - ) as pred_file: |
| 102 | + n = len(result) |
| 103 | + print('the size of test set is {}'.format(n)) |
| 104 | + with open(os.path.join(data_dir, 'output/Result_sklearn_RF.csv'), 'w') as fw: |
105 | 105 | fw.write('{},{}\n'.format('ImageId', 'Label'))
|
106 |
| - for line in pred_file.readlines()[1:]: |
107 |
| - splits = line.strip().split(',') |
108 |
| - fw.write('{},{}\n'.format(splits[0], result[i])) |
109 |
| - i += 1 |
110 |
| - fw.close() |
111 |
| - print('Result saved successfully...') |
| 106 | + for i in range(1, n + 1): |
| 107 | + fw.write('{},{}\n'.format(i, result[i - 1])) |
| 108 | + print('Result saved successfully... and the path = {}'.format(csvName)) |
112 | 109 |
|
113 | 110 |
|
114 | 111 | def trainRF():
|
@@ -151,7 +148,7 @@ def preRF():
|
151 | 148 | result = clf.predict(pcaPreData)
|
152 | 149 |
|
153 | 150 | # 结果的输出
|
154 |
| - saveResult(result,os.path.join(data_dir, 'output/Result_sklearn_rf.csv')) |
| 151 | + saveResult(result, os.path.join(data_dir, 'output/Result_sklearn_rf.csv')) |
155 | 152 | print("finish!")
|
156 | 153 | stopTime = time.time()
|
157 | 154 | print('PreModel load time used:%f s' % (stopTime - startTime))
|
|
0 commit comments