Skip to content

Commit 5866fd7

Browse files
authored
Add type check in free and change Exception to TypeError (ray-project#4221)
1 parent e96e06e commit 5866fd7

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

python/ray/internal/internal_api.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ def free(object_ids, local_only=False):
3535
raise TypeError("free() expects a list of ObjectID, got {}".format(
3636
type(object_ids)))
3737

38+
# Make sure that the values are object IDs.
39+
for object_id in object_ids:
40+
if not isinstance(object_id, ray.ObjectID):
41+
raise TypeError("Attempting to call `free` on the value {}, "
42+
"which is not an ray.ObjectID.".format(object_id))
43+
3844
worker.check_connected()
3945
with profiling.profile("ray.free"):
4046
if len(object_ids) == 0:

python/ray/worker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def put_object(self, object_id, value):
347347
"""
348348
# Make sure that the value is not an object ID.
349349
if isinstance(value, ObjectID):
350-
raise Exception(
350+
raise TypeError(
351351
"Calling 'put' on an ray.ObjectID is not allowed "
352352
"(similarly, returning an ray.ObjectID from a remote "
353353
"function is not allowed). If you really want to "
@@ -470,7 +470,7 @@ def get_object(self, object_ids):
470470
# Make sure that the values are object IDs.
471471
for object_id in object_ids:
472472
if not isinstance(object_id, ObjectID):
473-
raise Exception(
473+
raise TypeError(
474474
"Attempting to call `get` on the value {}, "
475475
"which is not an ray.ObjectID.".format(object_id))
476476
# Do an initial fetch for remote objects. We divide the fetch into
@@ -1800,7 +1800,7 @@ def connect(info,
18001800
driver_id = DriverID(_random_string())
18011801

18021802
if not isinstance(driver_id, DriverID):
1803-
raise Exception("The type of given driver id must be DriverID.")
1803+
raise TypeError("The type of given driver id must be DriverID.")
18041804

18051805
worker.worker_id = driver_id.binary()
18061806

0 commit comments

Comments
 (0)