From b91374bdf3d74a90cc623d421fa78ae5cf9bf8ae Mon Sep 17 00:00:00 2001 From: zicla Date: Tue, 3 Jul 2018 13:45:22 +0800 Subject: [PATCH] Add the file downloader. --- rest/alien_controller.go | 23 ++- rest/bean.go | 10 +- rest/matter_service.go | 329 ++++++++++++++++++++++++++++++++++++++- rest/router.go | 17 +- rest/web_error.go | 10 ++ 5 files changed, 370 insertions(+), 19 deletions(-) create mode 100644 rest/web_error.go diff --git a/rest/alien_controller.go b/rest/alien_controller.go index 963c37b..ea5d95d 100644 --- a/rest/alien_controller.go +++ b/rest/alien_controller.go @@ -3,12 +3,12 @@ package rest import ( "fmt" "net/http" - "net/url" "os" "regexp" "strconv" - "strings" "time" + "strings" + "net/url" ) type AlienController struct { @@ -347,22 +347,21 @@ func (this *AlienController) Download(writer http.ResponseWriter, request *http. this.PanicError(err) defer diskFile.Close() - // 防止中文乱码 - fileName := url.QueryEscape(matter.Name) - mimeType := GetMimeType(fileName) - writer.Header().Set("Content-Type", mimeType) - - //如果是图片或者文本或者视频就直接打开。其余的一律以下载形式返回。 - if strings.Index(mimeType, "image") != 0 && strings.Index(mimeType, "text") != 0 && strings.Index(mimeType, "video") != 0 { - writer.Header().Set("content-disposition", "attachment; filename=\""+fileName+"\"") - } - //对图片做缩放处理。 imageProcess := request.FormValue("imageProcess") if imageProcess == "resize" { this.matterService.ResizeImage(writer, request, matter, diskFile) } else { + //如果是图片或者文本或者视频就直接打开。其余的一律以下载形式返回。 + fileName := url.QueryEscape(matter.Name) + mimeType := GetMimeType(fileName) + if strings.Index(mimeType, "image") != 0 && strings.Index(mimeType, "text") != 0 && strings.Index(mimeType, "video") != 0 { + writer.Header().Set("content-disposition", "attachment; filename=\""+fileName+"\"") + } + + this.matterService.DownloadFile(writer, request, matter, diskFile) + //显示文件大小。 //fileInfo, err := diskFile.Stat() //if err != nil { diff --git a/rest/bean.go b/rest/bean.go index a5fcc77..65edb77 100644 --- a/rest/bean.go +++ b/rest/bean.go @@ -1,8 +1,11 @@ package rest +import "net/http" + type IBean interface { Init(context *Context) PanicError(err error); + PanicWebError(msg string, code int); } type Bean struct { @@ -16,6 +19,11 @@ func (this *Bean) Init(context *Context) { //处理错误的统一方法 func (this *Bean) PanicError(err error) { if err != nil { - panic(err) + panic(&WebError{Msg: err.Error(), Code: http.StatusInternalServerError}) } } + +//处理错误的统一方法 +func (this *Bean) PanicWebError(msg string, httpStatusCode int) { + panic(&WebError{Msg: msg, Code: httpStatusCode}) +} diff --git a/rest/matter_service.go b/rest/matter_service.go index 12a9649..c91ac17 100644 --- a/rest/matter_service.go +++ b/rest/matter_service.go @@ -9,6 +9,13 @@ import ( "net/http" "github.com/disintegration/imaging" "strconv" + "mime" + "path/filepath" + "net/url" + "errors" + "time" + "fmt" + "net/textproto" ) //@Service @@ -159,9 +166,14 @@ func (this *MatterService) Upload(file multipart.File, user *User, puuid string, return matter } -//处理图片下载功能。 +//图片预处理功能。 func (this *MatterService) ResizeImage(writer http.ResponseWriter, request *http.Request, matter *Matter, diskFile *os.File) { + // 防止中文乱码 + fileName := url.QueryEscape(matter.Name) + mimeType := GetMimeType(fileName) + writer.Header().Set("Content-Type", mimeType) + //当前的文件是否是图片,只有图片才能处理。 extension := GetExtension(matter.Name) formats := map[string]imaging.Format{ @@ -252,3 +264,318 @@ func (this *MatterService) ResizeImage(writer http.ResponseWriter, request *http } } + +//检查Last-Modified头。返回true: 请求已经完成了。(言下之意,文件没有修改过) 返回false:文件修改过。 +func checkLastModified(w http.ResponseWriter, r *http.Request, modifyTime time.Time) bool { + if modifyTime.IsZero() { + return false + } + + // The Date-Modified header truncates sub-second precision, so + // use mtime < t+1s instead of mtime <= t to check for unmodified. + if t, err := time.Parse(http.TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modifyTime.Before(t.Add(1*time.Second)) { + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + w.WriteHeader(http.StatusNotModified) + return true + } + w.Header().Set("Last-Modified", modifyTime.UTC().Format(http.TimeFormat)) + return false +} + +// 处理ETag标签 +// checkETag implements If-None-Match and If-Range checks. +// +// The ETag or modtime must have been previously set in the +// ResponseWriter's headers. The modtime is only compared at second +// granularity and may be the zero value to mean unknown. +// +// The return value is the effective request "Range" header to use and +// whether this request is now considered done. +func checkETag(w http.ResponseWriter, r *http.Request, modtime time.Time) (rangeReq string, done bool) { + etag := w.Header().Get("Etag") + rangeReq = r.Header.Get("Range") + + // Invalidate the range request if the entity doesn't match the one + // the client was expecting. + // "If-Range: version" means "ignore the Range: header unless version matches the + // current file." + // We only support ETag versions. + // The caller must have set the ETag on the response already. + if ir := r.Header.Get("If-Range"); ir != "" && ir != etag { + // The If-Range value is typically the ETag value, but it may also be + // the modtime date. See golang.org/issue/8367. + timeMatches := false + if !modtime.IsZero() { + if t, err := http.ParseTime(ir); err == nil && t.Unix() == modtime.Unix() { + timeMatches = true + } + } + if !timeMatches { + rangeReq = "" + } + } + + if inm := r.Header.Get("If-None-Match"); inm != "" { + // Must know ETag. + if etag == "" { + return rangeReq, false + } + + // TODO(bradfitz): non-GET/HEAD requests require more work: + // sending a different status code on matches, and + // also can't use weak cache validators (those with a "W/ + // prefix). But most users of ServeContent will be using + // it on GET or HEAD, so only support those for now. + if r.Method != "GET" && r.Method != "HEAD" { + return rangeReq, false + } + + // TODO(bradfitz): deal with comma-separated or multiple-valued + // list of If-None-match values. For now just handle the common + // case of a single item. + if inm == etag || inm == "*" { + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + w.WriteHeader(http.StatusNotModified) + return "", true + } + } + return rangeReq, false +} + +// httpRange specifies the byte range to be sent to the client. +type httpRange struct { + start, length int64 +} + +func (r httpRange) contentRange(size int64) string { + return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size) +} + +func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader { + return textproto.MIMEHeader{ + "Content-Range": {r.contentRange(size)}, + "Content-Type": {contentType}, + } +} + +// parseRange parses a Range header string as per RFC 2616. +func parseRange(s string, size int64) ([]httpRange, error) { + if s == "" { + return nil, nil // header not present + } + const b = "bytes=" + if !strings.HasPrefix(s, b) { + return nil, errors.New("invalid range") + } + var ranges []httpRange + for _, ra := range strings.Split(s[len(b):], ",") { + ra = strings.TrimSpace(ra) + if ra == "" { + continue + } + i := strings.Index(ra, "-") + if i < 0 { + return nil, errors.New("invalid range") + } + start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:]) + var r httpRange + if start == "" { + // If no start is specified, end specifies the + // range start relative to the end of the file. + i, err := strconv.ParseInt(end, 10, 64) + if err != nil { + return nil, errors.New("invalid range") + } + if i > size { + i = size + } + r.start = size - i + r.length = size - r.start + } else { + i, err := strconv.ParseInt(start, 10, 64) + if err != nil || i >= size || i < 0 { + return nil, errors.New("invalid range") + } + r.start = i + if end == "" { + // If no end is specified, range extends to end of the file. + r.length = size - r.start + } else { + i, err := strconv.ParseInt(end, 10, 64) + if err != nil || r.start > i { + return nil, errors.New("invalid range") + } + if i >= size { + i = size - 1 + } + r.length = i - r.start + 1 + } + } + ranges = append(ranges, r) + } + return ranges, nil +} + +// countingWriter counts how many bytes have been written to it. +type countingWriter int64 + +func (w *countingWriter) Write(p []byte) (n int, err error) { + *w += countingWriter(len(p)) + return len(p), nil +} + +// rangesMIMESize returns the number of bytes it takes to encode the +// provided ranges as a multipart response. +func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) { + var w countingWriter + mw := multipart.NewWriter(&w) + for _, ra := range ranges { + mw.CreatePart(ra.mimeHeader(contentType, contentSize)) + encSize += ra.length + } + mw.Close() + encSize += int64(w) + return +} + +func sumRangesSize(ranges []httpRange) (size int64) { + for _, ra := range ranges { + size += ra.length + } + return +} + +//文件下载功能。 +func (this *MatterService) DownloadFile(w http.ResponseWriter, r *http.Request, matter *Matter, content *os.File) { + + //显示文件大小。 + fileInfo, err := content.Stat() + if err != nil { + panic(err) + } + + modtime := fileInfo.ModTime() + + if checkLastModified(w, r, modtime) { + return + } + rangeReq, done := checkETag(w, r, modtime) + if done { + return + } + + code := http.StatusOK + + // From net/http/sniff.go + // The algorithm uses at most sniffLen bytes to make its decision. + const sniffLen = 512 + + // If Content-Type isn't set, use the file's extension to find it, but + // if the Content-Type is unset explicitly, do not sniff the type. + ctypes, haveType := w.Header()["Content-Type"] + var ctype string + if !haveType { + ctype = mime.TypeByExtension(filepath.Ext(fileInfo.Name())) + if ctype == "" { + // read a chunk to decide between utf-8 text and binary + var buf [sniffLen]byte + n, _ := io.ReadFull(content, buf[:]) + ctype = http.DetectContentType(buf[:n]) + _, err := content.Seek(0, os.SEEK_SET) // rewind to output whole file + if err != nil { + this.PanicWebError("无法准确定位文件", http.StatusInternalServerError) + return + } + } + w.Header().Set("Content-Type", ctype) + } else if len(ctypes) > 0 { + ctype = ctypes[0] + } + + size := fileInfo.Size() + + // handle Content-Range header. + sendSize := size + var sendContent io.Reader = content + if size >= 0 { + ranges, err := parseRange(rangeReq, size) + if err != nil { + panic("range header出错") + this.PanicWebError("range header error", http.StatusRequestedRangeNotSatisfiable) + return + } + if sumRangesSize(ranges) > size { + // The total number of bytes in all the ranges + // is larger than the size of the file by + // itself, so this is probably an attack, or a + // dumb client. Ignore the range request. + ranges = nil + } + switch { + case len(ranges) == 1: + // RFC 2616, Section 14.16: + // "When an HTTP message includes the content of a single + // range (for example, a response to a request for a + // single range, or to a request for a set of ranges + // that overlap without any holes), this content is + // transmitted with a Content-Range header, and a + // Content-Length header showing the number of bytes + // actually transferred. + // ... + // A response to a request for a single range MUST NOT + // be sent using the multipart/byteranges media type." + ra := ranges[0] + if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil { + this.PanicWebError(err.Error(), http.StatusRequestedRangeNotSatisfiable) + return + } + sendSize = ra.length + code = http.StatusPartialContent + w.Header().Set("Content-Range", ra.contentRange(size)) + case len(ranges) > 1: + sendSize = rangesMIMESize(ranges, ctype, size) + code = http.StatusPartialContent + + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary()) + sendContent = pr + defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish. + go func() { + for _, ra := range ranges { + part, err := mw.CreatePart(ra.mimeHeader(ctype, size)) + if err != nil { + pw.CloseWithError(err) + return + } + if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil { + pw.CloseWithError(err) + return + } + if _, err := io.CopyN(part, content, ra.length); err != nil { + pw.CloseWithError(err) + return + } + } + mw.Close() + pw.Close() + }() + } + + w.Header().Set("Accept-Ranges", "bytes") + if w.Header().Get("Content-Encoding") == "" { + w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10)) + } + } + + w.WriteHeader(code) + + if r.Method != "HEAD" { + io.CopyN(w, sendContent, sendSize) + } + +} diff --git a/rest/router.go b/rest/router.go index cecedc6..73e2a9b 100644 --- a/rest/router.go +++ b/rest/router.go @@ -40,16 +40,28 @@ func (this *Router) GlobalPanicHandler(writer http.ResponseWriter, request *http var webResult *WebResult = nil if value, ok := err.(string); ok { + writer.WriteHeader(http.StatusBadRequest) webResult = &WebResult{Code: RESULT_CODE_UTIL_EXCEPTION, Msg: value} } else if value, ok := err.(int); ok { + writer.WriteHeader(http.StatusBadRequest) webResult = ConstWebResult(value) } else if value, ok := err.(*WebResult); ok { + writer.WriteHeader(http.StatusBadRequest) webResult = value } else if value, ok := err.(WebResult); ok { + writer.WriteHeader(http.StatusBadRequest) webResult = &value + } else if value, ok := err.(*WebError); ok { + writer.WriteHeader(value.Code) + webResult = &WebResult{Code: RESULT_CODE_UTIL_EXCEPTION, Msg: value.Msg} + } else if value, ok := err.(WebError); ok { + writer.WriteHeader((&value).Code) + webResult = &WebResult{Code: RESULT_CODE_UTIL_EXCEPTION, Msg: (&value).Msg} } else if value, ok := err.(error); ok { + writer.WriteHeader(http.StatusBadRequest) webResult = &WebResult{Code: RESULT_CODE_UTIL_EXCEPTION, Msg: value.Error()} } else { + writer.WriteHeader(http.StatusInternalServerError) webResult = &WebResult{Code: RESULT_CODE_UTIL_EXCEPTION, Msg: "服务器未知错误"} } @@ -60,11 +72,6 @@ func (this *Router) GlobalPanicHandler(writer http.ResponseWriter, request *http var json = jsoniter.ConfigCompatibleWithStandardLibrary b, _ := json.Marshal(webResult) - if webResult.Code == RESULT_CODE_OK { - writer.WriteHeader(http.StatusOK) - } else { - writer.WriteHeader(http.StatusBadRequest) - } fmt.Fprintf(writer, string(b)) } } diff --git a/rest/web_error.go b/rest/web_error.go new file mode 100644 index 0000000..3caf7cc --- /dev/null +++ b/rest/web_error.go @@ -0,0 +1,10 @@ +package rest + +type WebError struct { + Code int `json:"code"` + Msg string `json:"msg"` +} + +func (this *WebError) Error() string { + return this.Msg +}