Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 6 additions & 17 deletions internal/compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,9 @@ func (c *Compiler) getVariables(t *ast.Task, call *ast.Call, evaluateShVars bool

getRangeFunc := func(dir string) func(k string, v ast.Var) error {
return func(k string, v ast.Var) error {
tr := templater.Templater{Vars: result}
cache := &templater.Cache{Vars: result}
// Replace values
newVar := ast.Var{}
switch value := v.Value.(type) {
case string:
newVar.Value = tr.Replace(value)
default:
newVar.Value = value
}
newVar.Sh = tr.Replace(v.Sh)
newVar.Ref = v.Ref
newVar.Json = tr.Replace(v.Json)
newVar.Yaml = tr.Replace(v.Yaml)
newVar.Dir = v.Dir
newVar := templater.ReplaceVar(v, cache)
// If the variable is a reference, we can resolve it
if newVar.Ref != "" {
newVar.Value = result.Get(newVar.Ref).Value
Expand All @@ -89,7 +78,7 @@ func (c *Compiler) getVariables(t *ast.Task, call *ast.Call, evaluateShVars bool
return nil
}
// Now we can check for errors since we've handled all the cases when we don't want to evaluate
if err := tr.Err(); err != nil {
if err := cache.Err(); err != nil {
return err
}
// Evaluate JSON
Expand Down Expand Up @@ -124,9 +113,9 @@ func (c *Compiler) getVariables(t *ast.Task, call *ast.Call, evaluateShVars bool
if t != nil {
// NOTE(@andreynering): We're manually joining these paths here because
// this is the raw task, not the compiled one.
tr := templater.Templater{Vars: result}
dir := tr.Replace(t.Dir)
if err := tr.Err(); err != nil {
cache := &templater.Cache{Vars: result}
dir := templater.Replace(t.Dir, cache)
if err := cache.Err(); err != nil {
return nil, err
}
dir = filepathext.SmartJoin(c.Dir, dir)
Expand Down
106 changes: 106 additions & 0 deletions internal/deepcopy/deepcopy.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package deepcopy

import (
"reflect"
)

type Copier[T any] interface {
DeepCopy() T
}
Expand Down Expand Up @@ -33,3 +37,105 @@ func Map[K comparable, V any](orig map[K]V) map[K]V {
}
return c
}

// TraverseStringsFunc runs the given function on every string in the given
// value by traversing it recursively. If the given value is a string, the
// function will run on a copy of the string and return it. If the value is a
// struct, map or a slice, the function will recursively call itself for each
// field or element of the struct, map or slice until all strings inside the
// struct or slice are replaced.
func TraverseStringsFunc[T any](v T, fn func(v string) (string, error)) (T, error) {
original := reflect.ValueOf(v)
if original.Kind() == reflect.Invalid || !original.IsValid() {
return v, nil
}
copy := reflect.New(original.Type()).Elem()

var traverseFunc func(copy, v reflect.Value) error
traverseFunc = func(copy, v reflect.Value) error {
switch v.Kind() {

case reflect.Ptr:
// Unwrap the pointer
originalValue := v.Elem()
// If the pointer is nil, do nothing
if !originalValue.IsValid() {
return nil
}
// Create an empty copy from the original value's type
copy.Set(reflect.New(originalValue.Type()))
// Unwrap the newly created pointer and call traverseFunc recursively
if err := traverseFunc(copy.Elem(), originalValue); err != nil {
return err
}

case reflect.Interface:
// Unwrap the interface
originalValue := v.Elem()
if !originalValue.IsValid() {
return nil
}
// Create an empty copy from the original value's type
copyValue := reflect.New(originalValue.Type()).Elem()
// Unwrap the newly created pointer and call traverseFunc recursively
if err := traverseFunc(copyValue, originalValue); err != nil {
return err
}
copy.Set(copyValue)

case reflect.Struct:
// Loop over each field and call traverseFunc recursively
for i := 0; i < v.NumField(); i += 1 {
if err := traverseFunc(copy.Field(i), v.Field(i)); err != nil {
return err
}
}

case reflect.Slice:
// Create an empty copy from the original value's type
copy.Set(reflect.MakeSlice(v.Type(), v.Len(), v.Cap()))
// Loop over each element and call traverseFunc recursively
for i := 0; i < v.Len(); i += 1 {
if err := traverseFunc(copy.Index(i), v.Index(i)); err != nil {
return err
}
}

case reflect.Map:
// Create an empty copy from the original value's type
copy.Set(reflect.MakeMap(v.Type()))
// Loop over each key
for _, key := range v.MapKeys() {
// Create a copy of each map index
originalValue := v.MapIndex(key)
if originalValue.IsNil() {
continue
}
copyValue := reflect.New(originalValue.Type()).Elem()
// Call traverseFunc recursively
if err := traverseFunc(copyValue, originalValue); err != nil {
return err
}
copy.SetMapIndex(key, copyValue)
}

case reflect.String:
rv, err := fn(v.String())
if err != nil {
return err
}
copy.Set(reflect.ValueOf(rv))

default:
copy.Set(v)
}

return nil
}

if err := traverseFunc(copy, original); err != nil {
return v, err
}

return copy.Interface().(T), nil
}
8 changes: 5 additions & 3 deletions internal/output/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@ package output
import (
"bytes"
"io"

"github.com/go-task/task/v3/internal/templater"
)

type Group struct {
Begin, End string
ErrorOnly bool
}

func (g Group) WrapWriter(stdOut, _ io.Writer, _ string, tmpl Templater) (io.Writer, io.Writer, CloseFunc) {
func (g Group) WrapWriter(stdOut, _ io.Writer, _ string, cache *templater.Cache) (io.Writer, io.Writer, CloseFunc) {
gw := &groupWriter{writer: stdOut}
if g.Begin != "" {
gw.begin = tmpl.Replace(g.Begin) + "\n"
gw.begin = templater.Replace(g.Begin, cache) + "\n"
}
if g.End != "" {
gw.end = tmpl.Replace(g.End) + "\n"
gw.end = templater.Replace(g.End, cache) + "\n"
}
return gw, gw, func(err error) error {
if g.ErrorOnly && err == nil {
Expand Down
4 changes: 3 additions & 1 deletion internal/output/interleaved.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package output

import (
"io"

"github.com/go-task/task/v3/internal/templater"
)

type Interleaved struct{}

func (Interleaved) WrapWriter(stdOut, stdErr io.Writer, _ string, _ Templater) (io.Writer, io.Writer, CloseFunc) {
func (Interleaved) WrapWriter(stdOut, stdErr io.Writer, _ string, _ *templater.Cache) (io.Writer, io.Writer, CloseFunc) {
return stdOut, stdErr, func(error) error { return nil }
}
10 changes: 2 additions & 8 deletions internal/output/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,12 @@ import (
"fmt"
"io"

"github.com/go-task/task/v3/internal/templater"
"github.com/go-task/task/v3/taskfile/ast"
)

// Templater executes a template engine.
// It is provided by the templater.Templater package.
type Templater interface {
// Replace replaces the provided template string with a rendered string.
Replace(tmpl string) string
}

type Output interface {
WrapWriter(stdOut, stdErr io.Writer, prefix string, tmpl Templater) (io.Writer, io.Writer, CloseFunc)
WrapWriter(stdOut, stdErr io.Writer, prefix string, cache *templater.Cache) (io.Writer, io.Writer, CloseFunc)
}

type CloseFunc func(err error) error
Expand Down
2 changes: 1 addition & 1 deletion internal/output/output_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestGroup(t *testing.T) {
}

func TestGroupWithBeginEnd(t *testing.T) {
tmpl := templater.Templater{
tmpl := templater.Cache{
Vars: &ast.Vars{
OrderedMap: omap.FromMap(map[string]ast.Var{
"VAR1": {Value: "example-value"},
Expand Down
4 changes: 3 additions & 1 deletion internal/output/prefixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import (
"fmt"
"io"
"strings"

"github.com/go-task/task/v3/internal/templater"
)

type Prefixed struct{}

func (Prefixed) WrapWriter(stdOut, _ io.Writer, prefix string, _ Templater) (io.Writer, io.Writer, CloseFunc) {
func (Prefixed) WrapWriter(stdOut, _ io.Writer, prefix string, _ *templater.Cache) (io.Writer, io.Writer, CloseFunc) {
pw := &prefixWriter{writer: stdOut, prefix: prefix}
return pw, pw, func(error) error { return pw.close() }
}
Expand Down
Loading