diff --git a/google/google.go b/google/google.go index 328ee96..41ced10 100644 --- a/google/google.go +++ b/google/google.go @@ -123,7 +123,7 @@ type credentialsFile struct { ServiceAccountImpersonationURL string `json:"service_account_impersonation_url"` CredentialSource externalaccount.CredentialSource `json:"credential_source"` QuotaProjectID string `json:"quota_project_id"` - WorkforcePoolUserProject string `json:"workforce_pool_user_project"` + WorkforcePoolUserProject string `json:"workforce_pool_user_project"` } func (f *credentialsFile) jwtConfig(scopes []string, subject string) *jwt.Config { @@ -177,7 +177,7 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar CredentialSource: f.CredentialSource, QuotaProjectID: f.QuotaProjectID, Scopes: params.Scopes, - WorkforcePoolUserProject: f.WorkforcePoolUserProject, + WorkforcePoolUserProject: f.WorkforcePoolUserProject, } return cfg.TokenSource(ctx) case "": diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index 088e174..38fb253 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -78,6 +78,7 @@ var ( regexp.MustCompile(`^iamcredentials\.[^\.\s\/\\]+\.googleapis\.com$`), regexp.MustCompile(`^[^\.\s\/\\]+-iamcredentials\.googleapis\.com$`), } + validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`) ) func validateURL(input string, patterns []*regexp.Regexp, scheme string) bool { @@ -91,14 +92,17 @@ func validateURL(input string, patterns []*regexp.Regexp, scheme string) bool { toTest := parsed.Host for _, pattern := range patterns { - - if valid := pattern.MatchString(toTest); valid { + if pattern.MatchString(toTest) { return true } } return false } +func validateWorkforceAudience(input string) bool { + return validWorkforceAudiencePattern.MatchString(input) +} + // TokenSource Returns an external account TokenSource struct. This is to be called by package google to construct a google.Credentials. func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { return c.tokenSource(ctx, validTokenURLPatterns, validImpersonateURLPatterns, "https") @@ -120,6 +124,13 @@ func (c *Config) tokenSource(ctx context.Context, tokenURLValidPats []*regexp.Re } } + if c.WorkforcePoolUserProject != "" { + valid := validateWorkforceAudience(c.Audience) + if !valid { + return nil, fmt.Errorf("oauth2/google: invalid Workforce Pool Audience provided while constructing tokenSource") + } + } + ts := tokenSource{ ctx: ctx, conf: c, @@ -229,10 +240,10 @@ func (ts tokenSource) Token() (*oauth2.Token, error) { ClientID: conf.ClientID, ClientSecret: conf.ClientSecret, } - var options map[string]string - if (ts.Config.WorkforcePoolUserProject != "") { - options = map[string]string{ - "userProject": ts.Config.WorkforcePoolUserProject, + var options map[string]interface{} + if conf.WorkforcePoolUserProject != "" { + options = map[string]interface{}{ + "userProject": conf.WorkforcePoolUserProject, } } stsResp, err := exchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options) diff --git a/google/internal/externalaccount/basecredentials_test.go b/google/internal/externalaccount/basecredentials_test.go index b1131d6..9367c49 100644 --- a/google/internal/externalaccount/basecredentials_test.go +++ b/google/internal/externalaccount/basecredentials_test.go @@ -210,3 +210,41 @@ func TestValidateURLImpersonateURL(t *testing.T) { }) } } + +func TestWorkforcePoolCreation(t *testing.T) { + var audienceValidatyTests = []struct { + audience string + expectSuccess bool + }{ + {"//iam.googleapis.com/locations/global/workforcePools/pool-id/providers/provider-id", true}, + {"//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", true}, + {"//iam.googleapis.com/locations/eu/workforcePools/workloadIdentityPools/providers/provider-id", true}, + {"identitynamespace:1f12345:my_provider", false}, + {"//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/pool-id/providers/provider-id", false}, + {"//iam.googleapis.com/projects/123456/locations/eu/workloadIdentityPools/pool-id/providers/provider-id", false}, + {"//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/workforcePools/providers/provider-id", false}, + {"//iamgoogleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", false}, + {"//iam.googleapiscom/locations/eu/workforcePools/pool-id/providers/provider-id", false}, + {"//iam.googleapis.com/locations/workforcePools/pool-id/providers/provider-id", false}, + {"//iam.googleapis.com/locations/eu/workforcePool/pool-id/providers/provider-id", false}, + {"//iam.googleapis.com/locations//workforcePool/pool-id/providers/provider-id", false}, + } + + ctx := context.Background() + for _, tt := range audienceValidatyTests { + t.Run(" "+tt.audience, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. + config := testConfig + config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL + config.ServiceAccountImpersonationURL = "https://iamcredentials.googleapis.com" + config.Audience = tt.audience + config.WorkforcePoolUserProject = "myProject" + _, err := config.TokenSource(ctx) + + if tt.expectSuccess && err != nil { + t.Errorf("got %v but want nil", err) + } else if !tt.expectSuccess && err == nil { + t.Errorf("got nil but expected an error") + } + }) + } +}