@@ -1197,18 +1197,26 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
1197
1197
1198
1198
var modelWrapError func (error ) error
1199
1199
var modelResources []userconfig.ModelResource
1200
+ var modelFileResources []userconfig.ModelResource
1200
1201
1201
1202
if hasSingleModel {
1202
1203
modelWrapError = func (err error ) error {
1203
- return errors .Wrap (err , userconfig .ModelsPathKey )
1204
+ return errors .Wrap (err , userconfig .ModelsKey , userconfig . ModelsPathKey )
1204
1205
}
1205
- modelResources = []userconfig.ModelResource {
1206
- {
1207
- Name : consts .SingleModelName ,
1208
- Path : * predictor .Models .Path ,
1209
- },
1206
+ modelResource := userconfig.ModelResource {
1207
+ Name : consts .SingleModelName ,
1208
+ Path : * predictor .Models .Path ,
1209
+ }
1210
+
1211
+ if strings .HasSuffix (* predictor .Models .Path , ".onnx" ) && provider != types .LocalProviderType {
1212
+ if err := validateONNXModelFilePath (* predictor .Models .Path , projectFiles .ProjectDir (), awsClient , gcpClient ); err != nil {
1213
+ return modelWrapError (err )
1214
+ }
1215
+ modelFileResources = append (modelFileResources , modelResource )
1216
+ } else {
1217
+ modelResources = append (modelResources , modelResource )
1218
+ * predictor .Models .Path = s .EnsureSuffix (* predictor .Models .Path , "/" )
1210
1219
}
1211
- * predictor .Models .Path = s .EnsureSuffix (* predictor .Models .Path , "/" )
1212
1220
}
1213
1221
if hasMultiModels {
1214
1222
if len (predictor .Models .Paths ) > 0 {
@@ -1225,8 +1233,15 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
1225
1233
path .Name ,
1226
1234
)
1227
1235
}
1228
- (* path ).Path = s .EnsureSuffix ((* path ).Path , "/" )
1229
- modelResources = append (modelResources , * path )
1236
+ if strings .HasSuffix ((* path ).Path , ".onnx" ) && provider != types .LocalProviderType {
1237
+ if err := validateONNXModelFilePath ((* path ).Path , projectFiles .ProjectDir (), awsClient , gcpClient ); err != nil {
1238
+ return errors .Wrap (modelWrapError (err ), path .Name )
1239
+ }
1240
+ modelFileResources = append (modelFileResources , * path )
1241
+ } else {
1242
+ (* path ).Path = s .EnsureSuffix ((* path ).Path , "/" )
1243
+ modelResources = append (modelResources , * path )
1244
+ }
1230
1245
}
1231
1246
}
1232
1247
@@ -1249,6 +1264,23 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
1249
1264
return modelWrapError (err )
1250
1265
}
1251
1266
1267
+ for _ , modelFileResource := range modelFileResources {
1268
+ s3Path := strings .HasPrefix (modelFileResource .Path , "s3://" )
1269
+ gcsPath := strings .HasPrefix (modelFileResource .Path , "gs://" )
1270
+ localPath := ! s3Path && ! gcsPath
1271
+
1272
+ * models = append (* models , CuratedModelResource {
1273
+ ModelResource : & userconfig.ModelResource {
1274
+ Name : modelFileResource .Name ,
1275
+ Path : modelFileResource .Path ,
1276
+ },
1277
+ S3Path : s3Path ,
1278
+ GCSPath : gcsPath ,
1279
+ LocalPath : localPath ,
1280
+ IsFilePath : true ,
1281
+ })
1282
+ }
1283
+
1252
1284
if hasMultiModels {
1253
1285
for _ , model := range * models {
1254
1286
if model .Name == consts .SingleModelName {
@@ -1264,6 +1296,58 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
1264
1296
return nil
1265
1297
}
1266
1298
1299
+ func validateONNXModelFilePath (modelPath string , projectDir string , awsClient * aws.Client , gcpClient * gcp.Client ) error {
1300
+ s3Path := strings .HasPrefix (modelPath , "s3://" )
1301
+ gcsPath := strings .HasPrefix (modelPath , "gs://" )
1302
+ localPath := ! s3Path && ! gcsPath
1303
+
1304
+ if s3Path {
1305
+ awsClientForBucket , err := aws .NewFromClientS3Path (modelPath , awsClient )
1306
+ if err != nil {
1307
+ return err
1308
+ }
1309
+
1310
+ bucket , modelPrefix , err := aws .SplitS3Path (modelPath )
1311
+ if err != nil {
1312
+ return err
1313
+ }
1314
+
1315
+ isS3File , err := awsClientForBucket .IsS3File (bucket , modelPrefix )
1316
+ if err != nil {
1317
+ return err
1318
+ }
1319
+
1320
+ if ! isS3File {
1321
+ return ErrorInvalidONNXModelFilePath (modelPrefix )
1322
+ }
1323
+ }
1324
+
1325
+ if gcsPath {
1326
+ bucket , modelPrefix , err := gcp .SplitGCSPath (modelPath )
1327
+ if err != nil {
1328
+ return err
1329
+ }
1330
+
1331
+ isGCSFile , err := gcpClient .IsGCSFile (bucket , modelPrefix )
1332
+ if err != nil {
1333
+ return err
1334
+ }
1335
+
1336
+ if ! isGCSFile {
1337
+ return ErrorInvalidONNXModelFilePath (modelPrefix )
1338
+ }
1339
+ }
1340
+
1341
+ if localPath {
1342
+ expandedLocalPath := files .RelToAbsPath (modelPath , projectDir )
1343
+ if err := files .CheckFile (expandedLocalPath ); err != nil {
1344
+ return err
1345
+ }
1346
+ }
1347
+
1348
+ return nil
1349
+ }
1350
+
1267
1351
func validatePythonPath (predictor * userconfig.Predictor , projectFiles ProjectFiles ) error {
1268
1352
if ! projectFiles .HasDir (* predictor .PythonPath ) {
1269
1353
return ErrorPythonPathNotFound (* predictor .PythonPath )
0 commit comments