Skip to content
This repository was archived by the owner on Aug 11, 2020. It is now read-only.
Prev Previous commit
Next Next commit
bug fixes and tests fixes
  • Loading branch information
kossak committed May 21, 2019
commit 0d987db41a0b9eb088763b1a0f81290f1e7cdeb0
10 changes: 7 additions & 3 deletions paperspace/commands/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import terminaltables
from click import style

from paperspace import config, client
from paperspace.commands import common
from paperspace.exceptions import BadResponseError
from paperspace.utils import get_terminal_lines
from paperspace.workspace import S3WorkspaceHandler, WorkspaceHandler
from paperspace.workspace import WorkspaceHandler


class JobsCommandBase(common.CommandBase):
Expand Down Expand Up @@ -137,7 +136,7 @@ def execute(self, json_):
archive_basename = self._workspace_handler.archive_basename
json_["workspaceFileName"] = archive_basename
self.api.headers["Content-Type"] = "multipart/form-data"
files = {"file": open(workspace_url, "rb")}
files = self._get_files_dict(workspace_url)
else:
json_["workspaceFileName"] = workspace_url

Expand All @@ -148,6 +147,11 @@ def execute(self, json_):
"Job created",
"Unknown error while creating job")

@staticmethod
def _get_files_dict(workspace_url):
files = {"file": open(workspace_url, "rb")}
return files

@staticmethod
def set_project(json_):
project_id = json_.get("projectId", json_.get("projectHandle"))
Expand Down
9 changes: 8 additions & 1 deletion paperspace/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,19 @@ def _clear_script_name(script_name, mode):
return script_name

def _create_command(self, mode, script, python_version=None):
command_parts = []
executor = self._get_executor(mode, python_version)
if executor:
command_parts.append(executor)

script_name = self._clear_script_name(script[0], mode)
command_parts = [executor, script_name]
if script_name:
command_parts.append(script_name)

script_params = ' '.join(script[1:])
if script_params:
command_parts.append(script_params)

command = ' '.join(command_parts)
return command

Expand Down
13 changes: 7 additions & 6 deletions paperspace/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,12 @@ def handle(self, input_data):
ignore_files = input_data.get('ignore_files')

if workspace_url:
return # nothing to do
return workspace_url # nothing to do

# Should be removed as soon it won't be necessary by PS_API
if workspace_path == 'none':
return 'none'
return workspace_path

if workspace_archive:
archive_path = os.path.abspath(workspace_archive)
else:
Expand Down Expand Up @@ -142,10 +143,10 @@ def __init__(self, experiments_api, logger=None):
self.experiments_api = experiments_api

def handle(self, input_data):
archive_path = super(S3WorkspaceHandler, self).handle(input_data)
if archive_path in ['none', None]:
return archive_path

workspace = super(S3WorkspaceHandler, self).handle(input_data)
if not self.archive_path:
return workspace
archive_path = workspace
file_name = os.path.basename(archive_path)
project_handle = input_data['projectHandle']

Expand Down
36 changes: 30 additions & 6 deletions tests/functional/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ class TestRunCommand(object):

@mock.patch("paperspace.client.requests.post")
@mock.patch("paperspace.workspace.WorkspaceHandler._zip_workspace")
def test_run_simple_file_with_args(self, workspace_zip_patched, post_patched):
workspace_zip_patched.return_value = '/dev/random'
@mock.patch("paperspace.commands.jobs.CreateJobCommand._get_files_dict")
def test_run_simple_file_with_args(self, get_files_patched, workspace_zip_patched, post_patched):
get_files_patched.return_value = mock.MagicMock()
workspace_zip_patched.return_value = '/foo/bar'
post_patched.return_value = MockResponse(status_code=200)

runner = CliRunner()
Expand All @@ -28,7 +30,7 @@ def test_run_simple_file_with_args(self, workspace_zip_patched, post_patched):
})
post_patched.assert_called_with(self.url,
params={'name': u'test', 'projectId': u'projectId',
'workspaceFileName': 'random',
'workspaceFileName': 'bar',
'command': 'python2 myscript.py a b',
'projectHandle': u'projectId',
'container': u'paperspace/tensorflow-python'},
Expand All @@ -38,9 +40,7 @@ def test_run_simple_file_with_args(self, workspace_zip_patched, post_patched):
json=None)

@mock.patch("paperspace.client.requests.post")
@mock.patch("paperspace.workspace.WorkspaceHandler._zip_workspace")
def test_run_python_command_with_args_and_no_workspace(self, workspace_zip_patched, post_patched):
workspace_zip_patched.return_value = '/dev/random'
def test_run_python_command_with_args_and_no_workspace(self, post_patched):
post_patched.return_value = MockResponse(status_code=200)

runner = CliRunner()
Expand All @@ -59,3 +59,27 @@ def test_run_python_command_with_args_and_no_workspace(self, workspace_zip_patch
files=None,
headers=expected_headers,
json=None)

@mock.patch("paperspace.client.requests.post")
@mock.patch("paperspace.workspace.WorkspaceHandler._zip_workspace")
def test_run_shell_command_with_args_with_s3_workspace(self, workspace_zip_patched, post_patched):
workspace_zip_patched.return_value = '/foo/bar'
post_patched.return_value = MockResponse(status_code=200)

runner = CliRunner()
result = runner.invoke(cli.cli,
[self.command_name] + self.common_commands + ["-s", "echo foo", "--workspaceUrl",
"s3://bucket/object"])

expected_headers = self.headers.copy()
post_patched.assert_called_with(self.url,
params={'name': u'test', 'projectId': u'projectId',
'workspaceFileName': 's3://bucket/object',
'workspaceUrl': 's3://bucket/object',
'command': 'echo foo',
'projectHandle': u'projectId',
'container': u'paperspace/tensorflow-python'},
data=None,
files=None,
headers=expected_headers,
json=None)