OAS: handle urlencoded and multipart form data correctly (#730)

* Start with tests

* Expected file

* Almost works

* Make test stab;e

* 5 examples, not 6

* Add param type

* multipart example

* parsing multipart

* Commit

* Commit

* multipart

* Write encoding into schema

* Stable test

* Add reset

* Refactoring

* Refactoring

* Maintain the required flag

* lint
This commit is contained in:
Andrey Pokhilko
2022-02-03 11:46:31 +03:00
committed by GitHub
parent 78be20fe4d
commit 5934be4da6
10 changed files with 833 additions and 88 deletions

View File

@@ -3,13 +3,17 @@ package oas
import (
"encoding/json"
"errors"
"fmt"
"github.com/chanced/openapi"
"github.com/google/uuid"
"github.com/nav-inc/datetime"
"github.com/up9inc/mizu/shared/logger"
"io"
"io/ioutil"
"mime"
"mime/multipart"
"net/textproto"
"net/url"
"sort"
"strconv"
"strings"
"sync"
@@ -126,6 +130,7 @@ func (g *SpecGen) GetSpec() (*openapi.OpenAPI, error) {
func suggestTags(oas *openapi.OpenAPI) {
paths := getPathsKeys(oas.Paths.Items)
sort.Strings(paths) // make it stable in case of multiple candidates
for len(paths) > 0 {
group := make([]string, 0)
group = append(group, paths[0])
@@ -155,21 +160,9 @@ func suggestTags(oas *openapi.OpenAPI) {
}
}
}
//groups[common] = group
}
}
func deleteFromSlice(s []string, val string) []string {
temp := s[:0]
for _, x := range s {
if x != val {
temp = append(temp, x)
}
}
return temp
}
func getPathsKeys(mymap map[openapi.PathValue]*openapi.PathObj) []string {
keys := make([]string, len(mymap))
@@ -364,7 +357,7 @@ func handleRequest(req *har.Request, opObj *openapi.Operation, isSuccess bool) e
}
if reqBody != nil {
reqCtype := getReqCtype(req)
reqCtype, _ := getReqCtype(req)
reqMedia, err := fillContent(reqResp{Req: req}, reqBody.Content, reqCtype)
if err != nil {
return err
@@ -471,12 +464,121 @@ func fillContent(reqResp reqResp, respContent openapi.Content, ctype string) (*o
}
}
content.Example = exampleMsg
if ctype == "application/x-www-form-urlencoded" && reqResp.Req != nil {
handleFormDataUrlencoded(text, content)
} else if strings.HasPrefix(ctype, "multipart/form-data") && reqResp.Req != nil {
_, params := getReqCtype(reqResp.Req)
handleFormDataMultipart(text, content, params)
}
if content.Example == nil && len(exampleMsg) > len(content.Example) {
content.Example = exampleMsg
}
}
return respContent[ctype], nil
}
func handleFormDataUrlencoded(text string, content *openapi.MediaType) {
formData, err := url.ParseQuery(text)
if err != nil {
logger.Log.Warningf("Could not decode urlencoded: %s", err)
return
}
parts := make([]PartWithBody, 0)
for name, vals := range formData {
for _, val := range vals {
part := new(multipart.Part)
part.Header = textproto.MIMEHeader{}
part.Header.Add("Content-Disposition", "form-data; name=\""+name+"\";")
parts = append(parts, PartWithBody{part: part, body: []byte(val)})
}
}
handleFormData(content, parts)
}
func handleFormData(content *openapi.MediaType, parts []PartWithBody) {
hadSchema := true
if content.Schema == nil {
hadSchema = false // will use it for required flags
content.Schema = new(openapi.SchemaObj)
content.Schema.Type = openapi.Types{openapi.TypeObject}
content.Schema.Properties = openapi.Schemas{}
}
props := &content.Schema.Properties
seenNames := map[string]struct{}{} // set equivalent in Go, yikes
for _, pwb := range parts {
name := pwb.part.FormName()
seenNames[name] = struct{}{}
existing, found := (*props)[name]
if !found {
existing = new(openapi.SchemaObj)
existing.Type = openapi.Types{openapi.TypeString}
(*props)[name] = existing
ctype := pwb.part.Header.Get("content-type")
if ctype != "" {
if existing.Keywords == nil {
existing.Keywords = map[string]json.RawMessage{}
}
existing.Keywords["contentMediaType"], _ = json.Marshal(ctype)
}
}
addSchemaExample(existing, string(pwb.body))
}
// handle required flag
if content.Schema.Required == nil {
if !hadSchema {
content.Schema.Required = make([]string, 0)
for name := range seenNames {
content.Schema.Required = append(content.Schema.Required, name)
}
} // else it's a known schema with no required fields
} else {
content.Schema.Required = intersectSliceWithMap(content.Schema.Required, seenNames)
}
}
type PartWithBody struct {
part *multipart.Part
body []byte
}
func handleFormDataMultipart(text string, content *openapi.MediaType, ctypeParams map[string]string) {
boundary, ok := ctypeParams["boundary"]
if !ok {
logger.Log.Errorf("Multipart header has no boundary")
return
}
mpr := multipart.NewReader(strings.NewReader(text), boundary)
parts := make([]PartWithBody, 0)
for {
part, err := mpr.NextPart()
if err == io.EOF {
break
}
if err != nil {
logger.Log.Errorf("Cannot parse multipart body: %v", err)
break
}
defer part.Close()
body, err := ioutil.ReadAll(part)
if err != nil {
logger.Log.Errorf("Error reading multipart Part %s: %v", part.Header, err)
}
parts = append(parts, PartWithBody{part: part, body: body})
}
handleFormData(content, parts)
}
func getRespCtype(resp *har.Response) string {
var ctype string
ctype = resp.Content.MimeType
@@ -493,8 +595,7 @@ func getRespCtype(resp *har.Response) string {
return mediaType
}
func getReqCtype(req *har.Request) string {
var ctype string
func getReqCtype(req *har.Request) (ctype string, params map[string]string) {
ctype = req.PostData.MimeType
for _, hdr := range req.Headers {
if strings.ToLower(hdr.Name) == "content-type" {
@@ -502,11 +603,12 @@ func getReqCtype(req *har.Request) string {
}
}
mediaType, _, err := mime.ParseMediaType(ctype)
mediaType, params, err := mime.ParseMediaType(ctype)
if err != nil {
return ""
logger.Log.Errorf("Cannot parse Content-Type header %q: %v", ctype, err)
return "", map[string]string{}
}
return mediaType
return mediaType, params
}
func getResponseObj(resp *har.Response, opObj *openapi.Operation, isSuccess bool) (*openapi.ResponseObj, error) {
@@ -591,56 +693,3 @@ func getOpObj(pathObj *openapi.PathObj, method string, createIfNone bool) (*open
return *op, isMissing, nil
}
type CounterMaps struct {
counterTotal Counter
counterMapTotal CounterMap
}
func (m *CounterMaps) processOp(opObj *openapi.Operation) error {
if _, ok := opObj.Extensions.Extension(CountersTotal); ok {
counter := new(Counter)
err := opObj.Extensions.DecodeExtension(CountersTotal, counter)
if err != nil {
return err
}
m.counterTotal.addOther(counter)
opObj.Description = setCounterMsgIfOk(opObj.Description, counter)
}
if _, ok := opObj.Extensions.Extension(CountersPerSource); ok {
counterMap := new(CounterMap)
err := opObj.Extensions.DecodeExtension(CountersPerSource, counterMap)
if err != nil {
return err
}
m.counterMapTotal.addOther(counterMap)
}
return nil
}
func (m *CounterMaps) processOas(oas *openapi.OpenAPI) error {
if oas.Extensions == nil {
oas.Extensions = openapi.Extensions{}
}
err := oas.Extensions.SetExtension(CountersTotal, m.counterTotal)
if err != nil {
return err
}
err = oas.Extensions.SetExtension(CountersPerSource, m.counterMapTotal)
if err != nil {
return nil
}
return nil
}
func setCounterMsgIfOk(oldStr string, cnt *Counter) string {
tpl := "Mizu observed %d entries (%d failed), at %.3f hits/s, average response time is %.3f seconds"
if oldStr == "" || (strings.HasPrefix(oldStr, "Mizu ") && strings.HasSuffix(oldStr, " seconds")) {
return fmt.Sprintf(tpl, cnt.Entries, cnt.Failures, cnt.SumDuration/float64(cnt.Entries), cnt.SumRT/float64(cnt.Entries))
}
return oldStr
}