Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 65 additions & 49 deletions internal/controller/appwrapper_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,74 +37,90 @@ type AppWrapperWebhook struct {

var _ webhook.CustomDefaulter = &AppWrapperWebhook{}

// Default implements webhook.CustomDefaulter so a webhook will be registered for the type
// Default ensures that Suspend is set appropriately when an AppWrapper is created
func (w *AppWrapperWebhook) Default(ctx context.Context, obj runtime.Object) error {
job := obj.(*workloadv1beta2.AppWrapper)
log.FromContext(ctx).Info("Applying defaults", "job", job)
jobframework.ApplyDefaultForSuspend((*AppWrapper)(job), w.ManageJobsWithoutQueueName)
aw := obj.(*workloadv1beta2.AppWrapper)
log.FromContext(ctx).Info("Applying defaults", "job", aw)
jobframework.ApplyDefaultForSuspend((*AppWrapper)(aw), w.ManageJobsWithoutQueueName)
return nil
}

//+kubebuilder:webhook:path=/validate-workload-codeflare-dev-v1beta2-appwrapper,mutating=false,failurePolicy=fail,sideEffects=None,groups=workload.codeflare.dev,resources=appwrappers,verbs=create;update,versions=v1beta2,name=vappwrapper.kb.io,admissionReviewVersions=v1

var _ webhook.CustomValidator = &AppWrapperWebhook{}

// ValidateCreate implements webhook.CustomValidator so a webhook will be registered for the type
// ValidateCreate validates invariants when an AppWrapper is created
func (w *AppWrapperWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
job := obj.(*workloadv1beta2.AppWrapper)
log.FromContext(ctx).Info("Validating create", "job", job)
return nil, w.validateCreate(job).ToAggregate()
}
aw := obj.(*workloadv1beta2.AppWrapper)
log.FromContext(ctx).Info("Validating create", "job", aw)

func (w *AppWrapperWebhook) validateCreate(job *workloadv1beta2.AppWrapper) field.ErrorList {
var allErrors field.ErrorList
allErrors := w.validateAppWrapperInvariants(ctx, aw)

if w.ManageJobsWithoutQueueName || jobframework.QueueName((*AppWrapper)(job)) != "" {
components := job.Spec.Components
componentsPath := field.NewPath("spec").Child("components")
podSpecCount := 0
for idx, component := range components {
podSetsPath := componentsPath.Index(idx).Child("podSets")
for psIdx, ps := range component.PodSets {
podSetPath := podSetsPath.Index(psIdx)
if ps.Path == "" {
allErrors = append(allErrors, field.Required(podSetPath.Child("path"), "podspec must specify path"))
}
if w.ManageJobsWithoutQueueName || jobframework.QueueName((*AppWrapper)(aw)) != "" {
allErrors = append(allErrors, jobframework.ValidateCreateForQueueName((*AppWrapper)(aw))...)
}

// TODO: Validatate the ps.Path resolves to a PodSpec
return nil, allErrors.ToAggregate()
}

// TODO: RBAC check to make sure that the user has the ability to create the wrapped resources
// ValidateUpdate validates invariants when an AppWrapper is updated
func (w *AppWrapperWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (admission.Warnings, error) {
oldAW := oldObj.(*workloadv1beta2.AppWrapper)
newAW := newObj.(*workloadv1beta2.AppWrapper)
log.FromContext(ctx).Info("Validating update", "job", newAW)

podSpecCount += 1
}
}
if podSpecCount == 0 {
allErrors = append(allErrors, field.Invalid(componentsPath, components, "components contains no podspecs"))
}
if podSpecCount > 8 {
allErrors = append(allErrors, field.Invalid(componentsPath, components, fmt.Sprintf("components contains %v podspecs; at most 8 are allowed", podSpecCount)))
}
allErrors := w.validateAppWrapperInvariants(ctx, newAW)

if w.ManageJobsWithoutQueueName || jobframework.QueueName((*AppWrapper)(newAW)) != "" {
allErrors = append(allErrors, jobframework.ValidateUpdateForQueueName((*AppWrapper)(oldAW), (*AppWrapper)(newAW))...)
allErrors = append(allErrors, jobframework.ValidateUpdateForWorkloadPriorityClassName((*AppWrapper)(oldAW), (*AppWrapper)(newAW))...)
}

allErrors = append(allErrors, jobframework.ValidateCreateForQueueName((*AppWrapper)(job))...)
return allErrors
return nil, allErrors.ToAggregate()
}

// ValidateUpdate implements webhook.CustomValidator so a webhook will be registered for the type
func (w *AppWrapperWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (admission.Warnings, error) {
oldJob := oldObj.(*workloadv1beta2.AppWrapper)
newJob := newObj.(*workloadv1beta2.AppWrapper)
if w.ManageJobsWithoutQueueName || jobframework.QueueName((*AppWrapper)(newJob)) != "" {
log.FromContext(ctx).Info("Validating update", "job", newJob)
allErrors := jobframework.ValidateUpdateForQueueName((*AppWrapper)(oldJob), (*AppWrapper)(newJob))
allErrors = append(allErrors, w.validateCreate(newJob)...)
allErrors = append(allErrors, jobframework.ValidateUpdateForWorkloadPriorityClassName((*AppWrapper)(oldJob), (*AppWrapper)(newJob))...)
return nil, allErrors.ToAggregate()
}
// ValidateDelete is a noop for us, but is required to implement the CustomValidator interface
func (w *AppWrapperWebhook) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) {
return nil, nil
}

// ValidateDelete implements webhook.CustomValidator so a webhook will be registered for the type
func (w *AppWrapperWebhook) ValidateDelete(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
return nil, nil
// validateAppWrapperInvariants checks AppWrapper-specific invariants
func (w *AppWrapperWebhook) validateAppWrapperInvariants(_ context.Context, aw *workloadv1beta2.AppWrapper) field.ErrorList {
allErrors := field.ErrorList{}
components := aw.Spec.Components
componentsPath := field.NewPath("spec").Child("components")
podSpecCount := 0

for idx, component := range components {

// Each PodSet.Path must specify a path within Template to a v1.PodSpecTemplate
podSetsPath := componentsPath.Index(idx).Child("podSets")
for psIdx, ps := range component.PodSets {
podSetPath := podSetsPath.Index(psIdx)
if ps.Path == "" {
allErrors = append(allErrors, field.Required(podSetPath.Child("path"), "podspec must specify path"))
}
if _, err := getPodTemplateSpec(component.Template.Raw, ps.Path); err != nil {
allErrors = append(allErrors, field.Invalid(podSetPath.Child("path"), ps.Path,
fmt.Sprintf("path does not refer to a v1.PodSpecTemplate: %v", err)))
}
podSpecCount += 1
}

// TODO: RBAC check to make sure that the user has permissions to create the component

// TODO: We could attempt to validate the object is namespaced and the namespace is the same as the AppWrapper's namespace
// This is currently enforced when the resources are created.

}

// Enforce Kueue limitation that 0 < podSpecCount <= 8
if podSpecCount == 0 {
allErrors = append(allErrors, field.Invalid(componentsPath, components, "components contains no podspecs"))
}
if podSpecCount > 8 {
allErrors = append(allErrors, field.Invalid(componentsPath, components, fmt.Sprintf("components contains %v podspecs; at most 8 are allowed", podSpecCount)))
}

return allErrors
}
69 changes: 69 additions & 0 deletions internal/controller/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
Copyright 2024 IBM Corporation.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package controller

import (
"fmt"
"strings"

v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
)

// getPodTemplateSpec parses raw as JSON and extracts a Kueue-compatible PodTemplateSpec at the given path within it
func getPodTemplateSpec(raw []byte, path string) (*v1.PodTemplateSpec, error) {
obj := &unstructured.Unstructured{}
if _, _, err := unstructured.UnstructuredJSONScheme.Decode(raw, nil, obj); err != nil {
return nil, err
}

// Walk down the path
parts := strings.Split(path, ".")
p := obj.UnstructuredContent()
var ok bool
for i := 1; i < len(parts); i++ {
p, ok = p[parts[i]].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("path element %v not found (segment %v of %v)", parts[i], i, len(parts))
}
}

// Extract the PodSpec that should be at candidatePTS.spec
candidatePTS := p
spec, ok := candidatePTS["spec"].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("content at %v does not contain a spec", path)
}
podSpec := &v1.PodSpec{}
if err := runtime.DefaultUnstructuredConverter.FromUnstructuredWithValidation(spec, podSpec, true); err != nil {
return nil, fmt.Errorf("content at %v.spec not parseable as a v1.PodSpec: %w", path, err)
}

// Construct the filtered PodTemplateSpec, copying only the metadata expected by Kueue
template := &v1.PodTemplateSpec{Spec: *podSpec}
if metadata, ok := candidatePTS["metadata"].(map[string]interface{}); ok {
if labels, ok := metadata["labels"].(map[string]string); ok {
template.ObjectMeta.Labels = labels
}
if annotations, ok := metadata["annotations"].(map[string]string); ok {
template.ObjectMeta.Annotations = annotations
}
}

return template, nil
}
35 changes: 8 additions & 27 deletions internal/controller/workload_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,8 @@ package controller

import (
"fmt"
"strings"

v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"sigs.k8s.io/controller-runtime/pkg/client"
kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
Expand Down Expand Up @@ -75,35 +71,20 @@ func (aw *AppWrapper) PodSets() []kueue.PodSet {
podSets := []kueue.PodSet{}
i := 0
for _, component := range aw.Spec.Components {
LOOP:
for _, podSet := range component.PodSets {
replicas := int32(1)
if podSet.Replicas != nil {
replicas = *podSet.Replicas
}
obj := &unstructured.Unstructured{}
if _, _, err := unstructured.UnstructuredJSONScheme.Decode(component.Template.Raw, nil, obj); err != nil {
continue LOOP // TODO handle error
template, err := getPodTemplateSpec(component.Template.Raw, podSet.Path)
if err == nil {
podSets = append(podSets, kueue.PodSet{
Name: aw.Name + "-" + fmt.Sprint(i),
Template: *template,
Count: replicas,
})
i++
}
parts := strings.Split(podSet.Path, ".")
p := obj.UnstructuredContent()
var ok bool
for i := 1; i < len(parts); i++ {
p, ok = p[parts[i]].(map[string]interface{})
if !ok {
continue LOOP // TODO handle error
}
}
var template v1.PodTemplateSpec
if err := runtime.DefaultUnstructuredConverter.FromUnstructured(p, &template); err != nil {
continue LOOP // TODO handle error
}
podSets = append(podSets, kueue.PodSet{
Name: aw.Name + "-" + fmt.Sprint(i),
Template: template,
Count: replicas,
})
i++
}
}
return podSets
Expand Down