1
1
from tornado import gen , web , locks
2
2
import traceback
3
3
import urllib .parse
4
-
5
4
from notebook .base .handlers import IPythonHandler
6
5
import threading
7
6
import json
11
10
12
11
from .pull import GitPuller
13
12
from .version import __version__
14
- from .hookspecs import handle_files
15
- from .plugins .zip_puller import ZipSourceGoogleDriveDownloader
16
- from .plugins .zip_puller import ZipSourceDropBoxDownloader
17
- from .plugins .zip_puller import ZipSourceWebDownloader
13
+ from . import hookspecs
18
14
import pluggy
15
+ import nbgitpuller
19
16
20
17
21
18
class SyncHandler (IPythonHandler ):
@@ -43,17 +40,38 @@ def emit(self, data):
43
40
self .write ('data: {}\n \n ' .format (serialized_data ))
44
41
yield self .flush ()
45
42
46
- def setup_plugins (self , repo ):
43
+ def setup_plugins (self , provider ):
47
44
pm = pluggy .PluginManager ("nbgitpuller" )
48
- pm .add_hookspecs (handle_files )
49
- if "drive.google.com" in repo :
50
- pm .register (ZipSourceGoogleDriveDownloader ())
51
- elif "dropbox.com" in repo :
52
- pm .register (ZipSourceDropBoxDownloader ())
53
- else :
54
- pm .register (ZipSourceWebDownloader ())
45
+ pm .add_hookspecs (hookspecs )
46
+ pm .load_setuptools_entrypoints ("nbgitpuller" , name = provider )
55
47
return pm
56
48
49
+ @gen .coroutine
50
+ def progress_loop (self , queue ):
51
+ while True :
52
+ try :
53
+ progress = queue .get_nowait ()
54
+ except Empty :
55
+ yield gen .sleep (0.1 )
56
+ continue
57
+ if progress is None :
58
+ yield gen .sleep (5 )
59
+ return
60
+ if isinstance (progress , Exception ):
61
+ self .emit ({
62
+ 'phase' : 'error' ,
63
+ 'message' : str (progress ),
64
+ 'output' : '\n ' .join ([
65
+ line .strip ()
66
+ for line in traceback .format_exception (
67
+ type (progress ), progress , progress .__traceback__
68
+ )
69
+ ])
70
+ })
71
+ return
72
+
73
+ self .emit ({'output' : progress , 'phase' : 'syncing' })
74
+
57
75
@web .authenticated
58
76
@gen .coroutine
59
77
def get (self ):
@@ -69,7 +87,7 @@ def get(self):
69
87
try :
70
88
repo = self .get_argument ('repo' )
71
89
branch = self .get_argument ('branch' , None )
72
- compressed = self .get_argument ('compressed ' , "false" )
90
+ provider = self .get_argument ('provider ' , None )
73
91
depth = self .get_argument ('depth' , None )
74
92
if depth :
75
93
depth = int (depth )
@@ -82,22 +100,31 @@ def get(self):
82
100
# so that all repos are always in scope after cloning. Sometimes
83
101
# server_root_dir will include things like `~` and so the path
84
102
# must be expanded.
85
- repo_parent_dir = os .path .join (os .path .expanduser (self .settings ['server_root_dir' ]),
86
- os .getenv ('NBGITPULLER_PARENTPATH' , '' ))
87
- repo_dir = os .path .join (repo_parent_dir , self .get_argument ('targetpath' , repo .split ('/' )[- 1 ]))
103
+ repo_parent_dir = os .path .join (os .path .expanduser (self .settings ['server_root_dir' ]), os .getenv ('NBGITPULLER_PARENTPATH' , '' ))
104
+ nbgitpuller .REPO_PARENT_DIR = repo_parent_dir
105
+
106
+ repo_dir = os .path .join (
107
+ repo_parent_dir ,
108
+ self .get_argument ('targetpath' , repo .split ('/' )[- 1 ]))
88
109
89
110
# We gonna send out event streams!
90
111
self .set_header ('content-type' , 'text/event-stream' )
91
112
self .set_header ('cache-control' , 'no-cache' )
92
113
93
- if compressed == 'true' :
94
- pm = self .setup_plugins (repo )
95
- results = pm .hook .handle_files (repo = repo , repo_parent_dir = repo_parent_dir )[0 ]
114
+ # if provider is specified then we are dealing with compressed
115
+ # archive and not a git repo
116
+ if provider is not None :
117
+ pm = self .setup_plugins (provider )
118
+ req_args = {k : v [0 ].decode () for k , v in self .request .arguments .items ()}
119
+ download_q = Queue ()
120
+ req_args ["progress_func" ] = lambda : self .progress_loop (download_q )
121
+ req_args ["download_q" ] = download_q
122
+ hf_args = {"query_line_args" : req_args }
123
+ results = pm .hook .handle_files (** hf_args )
96
124
repo_dir = repo_parent_dir + results ["unzip_dir" ]
97
125
repo = "file://" + results ["origin_repo_path" ]
98
126
99
127
gp = GitPuller (repo , repo_dir , branch = branch , depth = depth , parent = self .settings ['nbapp' ])
100
-
101
128
q = Queue ()
102
129
103
130
def pull ():
@@ -110,33 +137,11 @@ def pull():
110
137
q .put_nowait (e )
111
138
raise e
112
139
self .gp_thread = threading .Thread (target = pull )
113
-
114
140
self .gp_thread .start ()
115
-
116
- while True :
117
- try :
118
- progress = q .get_nowait ()
119
- except Empty :
120
- yield gen .sleep (0.5 )
121
- continue
122
- if progress is None :
123
- break
124
- if isinstance (progress , Exception ):
125
- self .emit ({
126
- 'phase' : 'error' ,
127
- 'message' : str (progress ),
128
- 'output' : '\n ' .join ([
129
- line .strip ()
130
- for line in traceback .format_exception (
131
- type (progress ), progress , progress .__traceback__
132
- )
133
- ])
134
- })
135
- return
136
-
137
- self .emit ({'output' : progress , 'phase' : 'syncing' })
138
-
141
+ self .progress_loop (q )
142
+ yield gen .sleep (3 )
139
143
self .emit ({'phase' : 'finished' })
144
+
140
145
except Exception as e :
141
146
self .emit ({
142
147
'phase' : 'error' ,
@@ -170,11 +175,10 @@ def initialize(self):
170
175
@gen .coroutine
171
176
def get (self ):
172
177
app_env = os .getenv ('NBGITPULLER_APP' , default = 'notebook' )
173
-
174
178
repo = self .get_argument ('repo' )
175
179
branch = self .get_argument ('branch' , None )
176
180
depth = self .get_argument ('depth' , None )
177
- compressed = self .get_argument ('compressed ' , "false" )
181
+ provider = self .get_argument ('provider ' , None )
178
182
urlPath = self .get_argument ('urlpath' , None ) or \
179
183
self .get_argument ('urlPath' , None )
180
184
subPath = self .get_argument ('subpath' , None ) or \
@@ -195,14 +199,17 @@ def get(self):
195
199
else :
196
200
path = 'tree/' + path
197
201
202
+ if provider is not None :
203
+ path = "tree/"
204
+
198
205
self .write (
199
206
self .render_template (
200
207
'status.html' ,
201
208
repo = repo ,
202
209
branch = branch ,
203
- compressed = compressed ,
204
210
path = path ,
205
211
depth = depth ,
212
+ provider = provider ,
206
213
targetpath = targetpath ,
207
214
version = __version__
208
215
))
@@ -239,3 +246,10 @@ def get(self):
239
246
)
240
247
241
248
self .redirect (new_url )
249
+
250
+
251
+ class ThreadWithResult (threading .Thread ):
252
+ def __init__ (self , group = None , target = None , name = None , args = (), kwargs = {}, * , daemon = None ):
253
+ def function ():
254
+ self .result = target (* args , ** kwargs )
255
+ super ().__init__ (group = group , target = function , name = name , daemon = daemon )
0 commit comments