diff --git a/xmltemplate/xmltemplate.go b/xmltemplate/xmltemplate.go index bf560f9..bfc369b 100644 --- a/xmltemplate/xmltemplate.go +++ b/xmltemplate/xmltemplate.go @@ -6,7 +6,8 @@ import ( "github.com/Graylog2/graylog-project-cli/config" "github.com/Graylog2/graylog-project-cli/logger" p "github.com/Graylog2/graylog-project-cli/project" - "io/ioutil" + "github.com/hashicorp/go-version" + "os" "text/template" ) @@ -49,12 +50,17 @@ func mavenAssemblies(project p.Project) map[string][]Assembly { func WriteXmlFile(config config.Config, project p.Project, templateFile string, outputFile string) { logger.Info("Generating %v", outputFile) - bts, err := ioutil.ReadFile(templateFile) + bts, err := os.ReadFile(templateFile) if err != nil { logger.Fatal("Error reading %v: %v", templateFile, err) } - tmpl, err := template.New(templateFile).Parse(string(bts)) + serverVersion, err := version.NewVersion(project.Server.Version()) + if err != nil { + logger.Fatal("Error parsing server version %q: %v", project.Server.Version(), err) + } + + tmpl, err := template.New(templateFile).Funcs(versionTemplateFuncs(serverVersion)).Parse(string(bts)) if err != nil { logger.Fatal("Error parsing template: %v", err) } @@ -73,7 +79,26 @@ func WriteXmlFile(config config.Config, project p.Project, templateFile string, logger.Fatal("Unable to execute template: %v", err) } - if err := ioutil.WriteFile(outputFile, buf.Bytes(), 0644); err != nil { + if err := os.WriteFile(outputFile, buf.Bytes(), 0644); err != nil { logger.Fatal("Unable to write file %v: %v", outputFile, err) } } + +func versionTemplateFuncs(serverVersion *version.Version) template.FuncMap { + compare := func(compareFunc func(*version.Version) bool) func(any) (bool, error) { + return func(versionValue any) (bool, error) { + givenVersion, err := version.NewVersion(fmt.Sprintf("%s", versionValue)) + if err != nil { + return false, fmt.Errorf("couldn't parse version %q: %w", versionValue, err) + } + return compareFunc(givenVersion), nil + } + } + + return template.FuncMap{ + "versionGt": compare(serverVersion.GreaterThan), + "versionGte": compare(serverVersion.GreaterThanOrEqual), + "versionLt": compare(serverVersion.LessThan), + "versionLte": compare(serverVersion.LessThanOrEqual), + } +} diff --git a/xmltemplate/xmltemplate_test.go b/xmltemplate/xmltemplate_test.go new file mode 100644 index 0000000..345d83d --- /dev/null +++ b/xmltemplate/xmltemplate_test.go @@ -0,0 +1,58 @@ +package xmltemplate + +import ( + "bytes" + "fmt" + "github.com/hashicorp/go-version" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" + "text/template" +) + +func TestVersionTemplateFuncs2(t *testing.T) { + tests := []struct { + serverVersion string + function string + version string + rendered bool + }{ + {"1.2.3", "versionGt", "1.2.3", false}, + {"1.2.3", "versionGt", "1.3.0", false}, + {"1.2.3", "versionGt", "1.2.0", true}, + + {"1.2.3", "versionGte", "1.2.3", true}, + {"1.2.3", "versionGte", "1.2.0", true}, + {"1.2.3", "versionGte", "1.3.0", false}, + + {"1.2.3", "versionLt", "1.2.3", false}, + {"1.2.3", "versionLt", "1.3.0", true}, + {"1.2.3", "versionLt", "1.2.0", false}, + + {"1.2.3", "versionLte", "1.2.3", true}, + {"1.2.3", "versionLte", "1.2.0", false}, + {"1.2.3", "versionLte", "1.3.0", true}, + } + for _, test := range tests { + testName := fmt.Sprintf("%s-%s-%s-%t", test.function, test.serverVersion, test.version, test.rendered) + + t.Run(testName, func(t *testing.T) { + serverVersion, err := version.NewVersion(test.serverVersion) + require.Nil(t, err) + + tmpl, err := template.New(testName).Funcs(versionTemplateFuncs(serverVersion)). + Parse(fmt.Sprintf(`{{ if %s "%s" }}RENDERED{{ end }}`, test.function, test.version)) + require.Nil(t, err) + + var out bytes.Buffer + + require.Nil(t, tmpl.Execute(&out, nil)) + + if test.rendered { + assert.Contains(t, out.String(), "RENDERED") + } else { + assert.NotContains(t, out.String(), "RENDERED") + } + }) + } +}