Skip to content

Commit dcf77d1

Browse files
修复rf-python3.6.py中文件打开没有关闭的问题
1 parent 97eabca commit dcf77d1

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

src/python/getting-started/digit-recognizer/rf-python3.6.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import time
2020

2121
# 数据路径
22-
data_dir = '/media/wsw/B634091A3408DF6D/data/kaggle/datasets/getting-started/digit-recognizer/'
22+
data_dir = '/Users/wuyanxue/Documents/GitHub/datasets/getting-started/digit-recognizer/'
2323

2424
# 加载数据
2525
def opencsv():
@@ -57,8 +57,8 @@ def dRPCA(data, COMPONENT_NUM=100):
5757
def trainModel(X_train, y_train):
5858
print('Train RF...')
5959
clf = RandomForestClassifier(
60-
n_estimators=140,
61-
max_depth=20,
60+
n_estimators=10,
61+
max_depth=10,
6262
min_samples_split=2,
6363
min_samples_leaf=1,
6464
random_state=34)
@@ -99,16 +99,13 @@ def getModel(filename):
9999
# 结果输出保存
100100
def saveResult(result, csvName):
101101
i = 0
102-
fw = open(csvName, 'w')
103-
with open(os.path.join(data_dir, 'output/sample_submission.csv')
104-
) as pred_file:
102+
n = len(result)
103+
print('the size of test set is {}'.format(n))
104+
with open(os.path.join(data_dir, 'output/Result_sklearn_RF.csv'), 'w') as fw:
105105
fw.write('{},{}\n'.format('ImageId', 'Label'))
106-
for line in pred_file.readlines()[1:]:
107-
splits = line.strip().split(',')
108-
fw.write('{},{}\n'.format(splits[0], result[i]))
109-
i += 1
110-
fw.close()
111-
print('Result saved successfully...')
106+
for i in range(1, n + 1):
107+
fw.write('{},{}\n'.format(i, result[i - 1]))
108+
print('Result saved successfully... and the path = {}'.format(csvName))
112109

113110

114111
def trainRF():
@@ -151,7 +148,7 @@ def preRF():
151148
result = clf.predict(pcaPreData)
152149

153150
# 结果的输出
154-
saveResult(result,os.path.join(data_dir, 'output/Result_sklearn_rf.csv'))
151+
saveResult(result, os.path.join(data_dir, 'output/Result_sklearn_rf.csv'))
155152
print("finish!")
156153
stopTime = time.time()
157154
print('PreModel load time used:%f s' % (stopTime - startTime))

0 commit comments

Comments
 (0)