address comments

This commit is contained in:
Jin Qin
2023-09-28 00:39:02 +00:00
parent 399b52f2ff
commit c50beac896
7 changed files with 70 additions and 93 deletions

View File

@@ -10,7 +10,7 @@ import (
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/sts_exchange"
"golang.org/x/oauth2/google/internal/stsexchange"
)
// now aliases time.Now for testing.
@@ -87,13 +87,13 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
return nil, errors.New("oauth2/google: The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret.")
}
clientAuth := sts_exchange.ClientAuthentication{
clientAuth := stsexchange.ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret,
}
stsResponse, err := sts_exchange.RefreshAccessToken(ts.ctx, conf.TokenURL, conf.RefreshToken, clientAuth, nil)
stsResponse, err := stsexchange.RefreshAccessToken(ts.ctx, conf.TokenURL, conf.RefreshToken, clientAuth, nil)
if err != nil {
return nil, err
}

View File

@@ -15,7 +15,7 @@ import (
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/sts_exchange"
"golang.org/x/oauth2/google/internal/stsexchange"
)
const expiryDelta = 10 * time.Second
@@ -33,59 +33,11 @@ type testRefreshTokenServer struct {
Authorization string
ContentType string
Body string
ResponsePayload *sts_exchange.Response
ResponsePayload *stsexchange.Response
Response string
server *httptest.Server
}
func (trts *testRefreshTokenServer) Run(t *testing.T) (string, error) {
if trts.server != nil {
return "", errors.New("Server is already running")
}
trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), trts.URL; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, trts.Authorization; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, trts.ContentType; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
}
if got, want := string(body), trts.Body; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
if trts.ResponsePayload != nil {
content, err := json.Marshal(trts.ResponsePayload)
if err != nil {
t.Fatalf("unable to marshall response JSON")
}
w.Write(content)
} else {
w.Write([]byte(trts.Response))
}
}))
return trts.server.URL, nil
}
func (trts *testRefreshTokenServer) Close() error {
if trts.server == nil {
return errors.New("No server is running")
}
trts.server.Close()
trts.server = nil
return nil
}
// Tests
func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
config := &Config{
Token: "AAAAAAA",
@@ -111,18 +63,18 @@ func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &sts_exchange.Response{
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
RefreshToken: "CCCCCCC",
},
}
url, err := server.Run(t)
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.Close()
defer server.close(t)
config := &Config{
RefreshToken: "BBBBBBBBB",
@@ -153,17 +105,17 @@ func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &sts_exchange.Response{
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
},
}
url, err := server.Run(t)
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.Close()
defer server.close(t)
config := &Config{
RefreshToken: "BBBBBBBBB",
@@ -191,17 +143,17 @@ func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &sts_exchange.Response{
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
},
}
url, err := server.Run(t)
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.Close()
defer server.close(t)
testCases := []struct {
name string
config Config
@@ -257,3 +209,51 @@ func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
})
}
}
func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
t.Helper()
if trts.server != nil {
return "", errors.New("Server is already running")
}
trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), trts.URL; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, trts.Authorization; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, trts.ContentType; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
}
if got, want := string(body), trts.Body; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
if trts.ResponsePayload != nil {
content, err := json.Marshal(trts.ResponsePayload)
if err != nil {
t.Fatalf("unable to marshall response JSON")
}
w.Write(content)
} else {
w.Write([]byte(trts.Response))
}
}))
return trts.server.URL, nil
}
func (trts *testRefreshTokenServer) close(t *testing.T) error {
t.Helper()
if trts.server == nil {
return errors.New("No server is running")
}
trts.server.Close()
trts.server = nil
return nil
}