Skip to content

Commit 9fe6b17

Browse files
committed
Fix mypy in test_paramiko_platform.py
1 parent 8846a8e commit 9fe6b17

File tree

3 files changed

+27
-13
lines changed

3 files changed

+27
-13
lines changed

autosubmit/platforms/paramiko_platform.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(self, expid: str, name: str, config: dict, auth_password: Optional[
9898
self._host_config_id = None
9999
self.submit_cmd = ""
100100
self._ftpChannel: Optional[paramiko.SFTPClient] = None
101-
self.transport = None
101+
self.transport: Optional[paramiko.Transport] = None
102102
self.channels = {}
103103
if sys.platform != "linux":
104104
self.poller = select.kqueue()
@@ -1159,8 +1159,7 @@ def exec_command(
11591159
self.restore_connection(None)
11601160
timeout = timeout + 60
11611161
retries = retries - 1
1162-
if retries <= 0:
1163-
return False, False, False
1162+
return False, False, False
11641163

11651164
def send_command_non_blocking(self, command, ignore_log):
11661165
thread = threading.Thread(target=self.send_command, args=(command, ignore_log))

autosubmit/platforms/paramiko_submitter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class ParamikoSubmitter:
110110

111111
def __init__(self, as_conf: 'AutosubmitConfig', auth_password: Optional[str] = None,
112112
local_auth_password=None):
113-
self.platforms = None
113+
self.platforms: Optional[dict[str, 'ParamikoPlatform']] = None
114114
self.load_platforms(as_conf=as_conf, auth_password=auth_password, local_auth_password=local_auth_password)
115115

116116
def load_local_platform(self, as_conf: 'AutosubmitConfig', experiment_data: Optional[dict] = None,

test/integration/test_paramiko_platform.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import cast, Generator, Optional, Protocol, Union, TYPE_CHECKING
2525

2626
import pytest
27+
from paramiko import ChannelFile # type: ignore[import]
2728

2829
from autosubmit.job.job import Job
2930
from autosubmit.job.job_common import Status
@@ -107,7 +108,8 @@ def exp_platform_server(autosubmit_exp, ssh_server, request) -> ExperimentPlatfo
107108
# to maintain and test).
108109
submitter = ParamikoSubmitter(as_conf=exp.as_conf)
109110

110-
ps_platform: 'PsPlatform' = submitter.platforms[_PLATFORM_NAME]
111+
assert submitter.platforms
112+
ps_platform: 'PsPlatform' = cast('PsPlatform', submitter.platforms[_PLATFORM_NAME])
111113

112114
return ExperimentPlatformServer(exp, ps_platform, ssh_server)
113115

@@ -123,14 +125,15 @@ class CreateJobParametersPlatformFixture(Protocol):
123125

124126
def __call__(
125127
self,
126-
experiment_data: Optional[dict] = None
128+
experiment_data: Optional[dict] = None,
129+
/
127130
) -> JobParametersPlatform:
128131
...
129132

130133

131134
@pytest.fixture
132135
def create_job_parameters_platform(autosubmit_exp) -> CreateJobParametersPlatformFixture:
133-
def job_parameters_platform(experiment_data: dict) -> JobParametersPlatform:
136+
def job_parameters_platform(experiment_data: Optional[dict] = None) -> JobParametersPlatform:
134137
exp = autosubmit_exp(_EXPID, experiment_data=experiment_data)
135138
slurm_platform: 'SlurmPlatform' = cast('SlurmPlatform', exp.platform)
136139

@@ -237,8 +240,8 @@ def test_send_file_errors(exp_platform_server: ExperimentPlatformServer):
237240
]
238241
)
239242
@pytest.mark.docker
240-
def test_send_command(cmd: str, error: Optional, x11_enabled: bool, mfa_enabled: bool, request: pytest.FixtureRequest,
241-
mocker):
243+
def test_send_command(cmd: str, error: Optional[Exception], x11_enabled: bool, mfa_enabled: bool,
244+
request: pytest.FixtureRequest, mocker):
242245
"""This test opens an SSH connection (via sftp) and sends a command."""
243246
if x11_enabled:
244247
request.applymarker('x11')
@@ -258,7 +261,7 @@ def test_send_command(cmd: str, error: Optional, x11_enabled: bool, mfa_enabled:
258261

259262
if error:
260263
assert exp_platform_server.platform.get_ssh_output_err() == ''
261-
with pytest.raises(error):
264+
with pytest.raises(error): # type: ignore
262265
exp_platform_server.platform.send_command(cmd, ignore_log=False, x11=x11_enabled)
263266

264267
stderr = exp_platform_server.platform.get_ssh_output_err()
@@ -282,6 +285,7 @@ def test_exec_command(exp_platform_server: 'ExperimentPlatformServer'):
282285
assert stdin is not False
283286
assert stderr is not False
284287
# The stdout contents should be [b"user_name\n"]; thus the ugly list comprehension + extra code.
288+
assert isinstance(stdout, ChannelFile)
285289
assert user == str(''.join([x.decode('UTF-8').strip() for x in stdout.readlines()]))
286290

287291

@@ -310,6 +314,7 @@ def test_exec_command_invalid_command(command: str, expected: str, x11: bool, re
310314
exp_platform_server.platform.connect(None, reconnect=False, log_recovery_process=False)
311315

312316
stdin, stdout, stderr = exp_platform_server.platform.exec_command(command, x11=x11)
317+
assert isinstance(stdout, ChannelFile)
313318
assert stdin is not False
314319
assert stderr is not False
315320
# The stdout contents should be [b"user_name\n"]; thus the ugly list comprehension + extra code.
@@ -327,6 +332,7 @@ def test_exec_command_after_a_reset(exp_platform_server: 'ExperimentPlatformServ
327332
exp_platform_server.platform.connect(None, reconnect=False, log_recovery_process=False)
328333

329334
stdin, stdout, stderr = exp_platform_server.platform.exec_command('whoami')
335+
assert isinstance(stdout, ChannelFile)
330336
assert stdin is not False
331337
assert stderr is not False
332338
# The stdout contents should be [b"user_name\n"]; thus the ugly list comprehension + extra code.
@@ -358,6 +364,7 @@ def test_exec_command_ssh_session_not_active(x11: bool, retries: int, command: s
358364
# But while that's OK, we can also avoid mocking by simply
359365
# closing the connection.
360366

367+
assert exp_platform_server.platform.transport
361368
exp_platform_server.platform.transport.close()
362369

363370
stdin, stdout, stderr = exp_platform_server.platform.exec_command(
@@ -367,6 +374,7 @@ def test_exec_command_ssh_session_not_active(x11: bool, retries: int, command: s
367374
)
368375

369376
# This will be true iff the ``ps_platform.restore_connection(None)`` ran without errors.
377+
assert isinstance(stdout, ChannelFile)
370378
assert stdin is not False
371379
assert stderr is not False
372380
# The stdout contents should be [b"user_name\n"]; thus the ugly list comprehension + extra code.
@@ -411,11 +419,14 @@ def test_fs_operations(exp_platform_server: 'ExperimentPlatformServer', request)
411419
assert exp_platform_server.platform.send_file(local_file.name)
412420

413421
contents = exp_platform_server.platform.read_file(str(remote_file))
422+
assert contents
414423
assert contents.decode('UTF-8').strip() == text
415-
assert None is exp_platform_server.platform.read_file(str(file_not_found))
424+
assert exp_platform_server.platform.read_file(str(file_not_found)) is None
416425

417-
assert exp_platform_server.platform.get_file_size(str(remote_file)) > 0
418-
assert None is exp_platform_server.platform.get_file_size(str(file_not_found))
426+
file_size: Optional[int] = exp_platform_server.platform.get_file_size(str(remote_file))
427+
assert file_size
428+
assert file_size > 0
429+
assert exp_platform_server.platform.get_file_size(str(file_not_found)) is None
419430

420431
assert exp_platform_server.platform.check_absolute_file_exists(str(remote_file))
421432
assert not exp_platform_server.platform.check_absolute_file_exists(str(file_not_found))
@@ -463,6 +474,7 @@ def test_exec_command_with_x11(x11_enabled: bool, user_or_false: Union[str, bool
463474
if type(user_or_false) is bool:
464475
assert user_or_false == stdout
465476
else:
477+
assert isinstance(stdout, ChannelFile)
466478
assert user_or_false == stdout.readline().decode('UTF-8').strip()
467479

468480

@@ -483,6 +495,9 @@ def test_xclock(exp_platform_server: ExperimentPlatformServer):
483495

484496
_, stdout, stderr = ps_platform.exec_command('timeout 1 xclock', x11=True)
485497

498+
assert isinstance(stdout, ChannelFile)
499+
assert isinstance(stderr, ChannelFile)
500+
486501
assert ''.join(stdout.readlines()) == ''
487502
assert ''.join(stderr.readlines()) == ''
488503

0 commit comments

Comments
 (0)