From daa58478596b409309df241caf786b50ccd93818 Mon Sep 17 00:00:00 2001 From: zicla Date: Wed, 24 Apr 2019 19:36:10 +0800 Subject: [PATCH] Refine the directory structure of this project. --- rest/alien_controller.go | 13 +- rest/alien_service.go | 5 +- rest/base_controller.go | 23 +- rest/bean.go | 11 +- rest/dashboard_controller.go | 7 +- rest/dav_controller.go | 3 +- rest/download/download.go | 364 ++++++++++++++++++++++++++++++++ rest/footprint_controller.go | 7 +- rest/image_cache_controller.go | 9 +- rest/image_cache_service.go | 9 +- rest/install_controller.go | 15 +- rest/matter_controller.go | 21 +- rest/matter_dao.go | 3 +- rest/matter_model.go | 4 +- rest/matter_service.go | 338 +---------------------------- rest/preference_controller.go | 7 +- rest/preference_dao.go | 3 +- rest/{ => result}/web_result.go | 2 +- rest/router.go | 24 ++- rest/user_controller.go | 22 +- rest/user_service.go | 40 ++++ rest/{ => util}/util_mime.go | 2 +- 22 files changed, 517 insertions(+), 415 deletions(-) create mode 100644 rest/download/download.go rename rest/{ => result}/web_result.go (99%) rename rest/{ => util}/util_mime.go (99%) diff --git a/rest/alien_controller.go b/rest/alien_controller.go index 796a0f7..58c81a9 100644 --- a/rest/alien_controller.go +++ b/rest/alien_controller.go @@ -6,6 +6,7 @@ import ( "regexp" "strconv" "strings" + "tank/rest/result" "time" ) @@ -137,7 +138,7 @@ func (this *AlienController) CheckRequestUser(writer http.ResponseWriter, reques } //系统中的用户x要获取一个UploadToken,用于提供给x信任的用户上传文件。 -func (this *AlienController) FetchUploadToken(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *AlienController) FetchUploadToken(writer http.ResponseWriter, request *http.Request) *result.WebResult { //文件名。 filename := request.FormValue("filename") @@ -219,7 +220,7 @@ func (this *AlienController) FetchUploadToken(writer http.ResponseWriter, reques } //系统中的用户x 拿着某个文件的uuid来确认是否其信任的用户已经上传好了。 -func (this *AlienController) Confirm(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *AlienController) Confirm(writer http.ResponseWriter, request *http.Request) *result.WebResult { matterUuid := request.FormValue("matterUuid") if matterUuid == "" { @@ -237,7 +238,7 @@ func (this *AlienController) Confirm(writer http.ResponseWriter, request *http.R } //系统中的用户x 信任的用户上传文件。这个接口需要支持跨域。 -func (this *AlienController) Upload(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *AlienController) Upload(writer http.ResponseWriter, request *http.Request) *result.WebResult { //允许跨域请求。 this.allowCORS(writer) if request.Method == "OPTIONS" { @@ -288,7 +289,7 @@ func (this *AlienController) Upload(writer http.ResponseWriter, request *http.Re } //给一个指定的url,从该url中去拉取文件回来。此处采用uploadToken的模式。 -func (this *AlienController) CrawlToken(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *AlienController) CrawlToken(writer http.ResponseWriter, request *http.Request) *result.WebResult { //允许跨域请求。 this.allowCORS(writer) if request.Method == "OPTIONS" { @@ -336,7 +337,7 @@ func (this *AlienController) CrawlToken(writer http.ResponseWriter, request *htt } //通过一个url直接上传,无需借助uploadToken. -func (this *AlienController) CrawlDirect(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *AlienController) CrawlDirect(writer http.ResponseWriter, request *http.Request) *result.WebResult { //文件名。 filename := request.FormValue("filename") @@ -377,7 +378,7 @@ func (this *AlienController) CrawlDirect(writer http.ResponseWriter, request *ht } //系统中的用户x要获取一个DownloadToken,用于提供给x信任的用户下载文件。 -func (this *AlienController) FetchDownloadToken(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *AlienController) FetchDownloadToken(writer http.ResponseWriter, request *http.Request) *result.WebResult { matterUuid := request.FormValue("matterUuid") if matterUuid == "" { diff --git a/rest/alien_service.go b/rest/alien_service.go index fa4e459..e6e04fa 100644 --- a/rest/alien_service.go +++ b/rest/alien_service.go @@ -3,6 +3,7 @@ package rest import ( "fmt" "net/http" + "tank/rest/result" "time" ) @@ -96,7 +97,7 @@ func (this *AlienService) PreviewOrDownload( tokenUser := this.userDao.CheckByUuid(downloadToken.UserUuid) if matter.UserUuid != tokenUser.Uuid { - panic(CODE_WRAPPER_UNAUTHORIZED) + panic(result.CODE_WRAPPER_UNAUTHORIZED) } //下载之后立即过期掉。如果是分块下载的,必须以最终获取到完整的数据为准。 @@ -108,7 +109,7 @@ func (this *AlienService) PreviewOrDownload( //判断文件的所属人是否正确 operator := this.findUser(writer, request) if operator == nil || (operator.Role != USER_ROLE_ADMINISTRATOR && matter.UserUuid != operator.Uuid) { - panic(CODE_WRAPPER_UNAUTHORIZED) + panic(result.CODE_WRAPPER_UNAUTHORIZED) } } diff --git a/rest/base_controller.go b/rest/base_controller.go index 5afe763..a43ab29 100644 --- a/rest/base_controller.go +++ b/rest/base_controller.go @@ -5,6 +5,7 @@ import ( "github.com/json-iterator/go" "go/types" "net/http" + "tank/rest/result" ) type IController interface { @@ -49,13 +50,13 @@ func (this *BaseController) HandleRoutes(writer http.ResponseWriter, request *ht } //需要进行登录验证的wrap包装 -func (this *BaseController) Wrap(f func(writer http.ResponseWriter, request *http.Request) *WebResult, qualifiedRole string) func(w http.ResponseWriter, r *http.Request) { +func (this *BaseController) Wrap(f func(writer http.ResponseWriter, request *http.Request) *result.WebResult, qualifiedRole string) func(w http.ResponseWriter, r *http.Request) { return func(writer http.ResponseWriter, request *http.Request) { //writer和request赋值给自己。 - var webResult *WebResult = nil + var webResult *result.WebResult = nil //只有游客接口不需要登录 if qualifiedRole != USER_ROLE_GUEST { @@ -63,10 +64,10 @@ func (this *BaseController) Wrap(f func(writer http.ResponseWriter, request *htt if user.Status == USER_STATUS_DISABLED { //判断用户是否被禁用。 - webResult = ConstWebResult(CODE_WRAPPER_USER_DISABLED) + webResult = result.ConstWebResult(result.CODE_WRAPPER_USER_DISABLED) } else { if qualifiedRole == USER_ROLE_ADMINISTRATOR && user.Role != USER_ROLE_ADMINISTRATOR { - webResult = ConstWebResult(CODE_WRAPPER_UNAUTHORIZED) + webResult = result.ConstWebResult(result.CODE_WRAPPER_UNAUTHORIZED) } else { webResult = f(writer, request) } @@ -86,7 +87,7 @@ func (this *BaseController) Wrap(f func(writer http.ResponseWriter, request *htt this.PanicError(err) - writer.WriteHeader(FetchHttpStatus(webResult.Code)) + writer.WriteHeader(result.FetchHttpStatus(webResult.Code)) _, err = fmt.Fprintf(writer, string(b)) this.PanicError(err) @@ -99,20 +100,20 @@ func (this *BaseController) Wrap(f func(writer http.ResponseWriter, request *htt } //返回成功的结果。支持放置三种类型 1.字符串 2. WebResult对象 3.空指针 4.任意类型 -func (this *BaseController) Success(data interface{}) *WebResult { - var webResult *WebResult = nil +func (this *BaseController) Success(data interface{}) *result.WebResult { + var webResult *result.WebResult = nil if value, ok := data.(string); ok { //返回一句普通的消息 - webResult = &WebResult{Code: CODE_WRAPPER_OK.Code, Msg: value} - } else if value, ok := data.(*WebResult); ok { + webResult = &result.WebResult{Code: result.CODE_WRAPPER_OK.Code, Msg: value} + } else if value, ok := data.(*result.WebResult); ok { //返回一个webResult对象 webResult = value } else if _, ok := data.(types.Nil); ok { //返回一个空指针 - webResult = ConstWebResult(CODE_WRAPPER_OK) + webResult = result.ConstWebResult(result.CODE_WRAPPER_OK) } else { //返回的类型不明确。 - webResult = &WebResult{Code: CODE_WRAPPER_OK.Code, Data: data} + webResult = &result.WebResult{Code: result.CODE_WRAPPER_OK.Code, Data: data} } return webResult } diff --git a/rest/bean.go b/rest/bean.go index a730455..c55275b 100644 --- a/rest/bean.go +++ b/rest/bean.go @@ -3,6 +3,7 @@ package rest import ( "fmt" "net/http" + "tank/rest/result" ) type IBean interface { @@ -42,22 +43,22 @@ func (this *Bean) PanicError(err error) { //请求参数有问题 func (this *Bean) PanicBadRequest(format string, v ...interface{}) { - panic(CustomWebResult(CODE_WRAPPER_BAD_REQUEST, fmt.Sprintf(format, v...))) + panic(result.CustomWebResult(result.CODE_WRAPPER_BAD_REQUEST, fmt.Sprintf(format, v...))) } //没有权限 func (this *Bean) PanicUnauthorized(format string, v ...interface{}) { - panic(CustomWebResult(CODE_WRAPPER_UNAUTHORIZED, fmt.Sprintf(format, v...))) + panic(result.CustomWebResult(result.CODE_WRAPPER_UNAUTHORIZED, fmt.Sprintf(format, v...))) } //没有找到 func (this *Bean) PanicNotFound(format string, v ...interface{}) { - panic(CustomWebResult(CODE_WRAPPER_NOT_FOUND, fmt.Sprintf(format, v...))) + panic(result.CustomWebResult(result.CODE_WRAPPER_NOT_FOUND, fmt.Sprintf(format, v...))) } //服务器内部出问题 func (this *Bean) PanicServer(format string, v ...interface{}) { - panic(CustomWebResult(CODE_WRAPPER_SERVER, fmt.Sprintf(format, v...))) + panic(result.CustomWebResult(result.CODE_WRAPPER_SERVER, fmt.Sprintf(format, v...))) } //能找到一个user就找到一个 @@ -94,7 +95,7 @@ func (this *Bean) findUser(writer http.ResponseWriter, request *http.Request) *U //获取当前登录的用户,找不到就返回登录错误 func (this *Bean) checkUser(writer http.ResponseWriter, request *http.Request) *User { if this.findUser(writer, request) == nil { - panic(ConstWebResult(CODE_WRAPPER_LOGIN)) + panic(result.ConstWebResult(result.CODE_WRAPPER_LOGIN)) } else { return this.findUser(writer, request) } diff --git a/rest/dashboard_controller.go b/rest/dashboard_controller.go index 0c5b28a..d4b9534 100644 --- a/rest/dashboard_controller.go +++ b/rest/dashboard_controller.go @@ -3,6 +3,7 @@ package rest import ( "net/http" "strconv" + "tank/rest/result" ) type DashboardController struct { @@ -41,14 +42,14 @@ func (this *DashboardController) RegisterRoutes() map[string]func(writer http.Re } //过去七天分时调用量 -func (this *DashboardController) InvokeList(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *DashboardController) InvokeList(writer http.ResponseWriter, request *http.Request) *result.WebResult { return this.Success("") } //按照分页的方式获取某个图片缓存夹下图片缓存和子图片缓存夹的列表,通常情况下只有一页。 -func (this *DashboardController) Page(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *DashboardController) Page(writer http.ResponseWriter, request *http.Request) *result.WebResult { //如果是根目录,那么就传入root. pageStr := request.FormValue("page") @@ -96,7 +97,7 @@ func (this *DashboardController) Page(writer http.ResponseWriter, request *http. } -func (this *DashboardController) ActiveIpTop10(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *DashboardController) ActiveIpTop10(writer http.ResponseWriter, request *http.Request) *result.WebResult { list := this.dashboardDao.ActiveIpTop10() return this.Success(list) } diff --git a/rest/dav_controller.go b/rest/dav_controller.go index f709e37..70868ea 100644 --- a/rest/dav_controller.go +++ b/rest/dav_controller.go @@ -7,6 +7,7 @@ import ( "net/http" "regexp" "strings" + "tank/rest/result" ) /** @@ -77,7 +78,7 @@ func (this *DavController) CheckCurrentUser(writer http.ResponseWriter, request //要求前端使用Basic的形式授权 writer.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - panic(ConstWebResult(CODE_WRAPPER_LOGIN)) + panic(result.ConstWebResult(result.CODE_WRAPPER_LOGIN)) } diff --git a/rest/download/download.go b/rest/download/download.go new file mode 100644 index 0000000..3ac46aa --- /dev/null +++ b/rest/download/download.go @@ -0,0 +1,364 @@ +package download + +import ( + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/textproto" + "net/url" + "os" + "strconv" + "strings" + "tank/rest/result" + "tank/rest/util" + "time" +) + +// HttpRange specifies the byte range to be sent to the client. +type HttpRange struct { + start int64 + length int64 +} + +func (r HttpRange) contentRange(size int64) string { + return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size) +} + +func (r HttpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader { + return textproto.MIMEHeader{ + "Content-Range": {r.contentRange(size)}, + "Content-Type": {contentType}, + } +} + +// CountingWriter counts how many bytes have been written to it. +type CountingWriter int64 + +func (w *CountingWriter) Write(p []byte) (n int, err error) { + *w += CountingWriter(len(p)) + return len(p), nil +} + +//检查Last-Modified头。返回true: 请求已经完成了。(言下之意,文件没有修改过) 返回false:文件修改过。 +func CheckLastModified(w http.ResponseWriter, r *http.Request, modifyTime time.Time) bool { + if modifyTime.IsZero() { + return false + } + + // The Date-Modified header truncates sub-second precision, so + // use mtime < t+1s instead of mtime <= t to check for unmodified. + if t, err := time.Parse(http.TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modifyTime.Before(t.Add(1*time.Second)) { + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + w.WriteHeader(http.StatusNotModified) + return true + } + w.Header().Set("Last-Modified", modifyTime.UTC().Format(http.TimeFormat)) + return false +} + +// 处理ETag标签 +// CheckETag implements If-None-Match and If-Range checks. +// +// The ETag or modtime must have been previously set in the +// ResponseWriter's headers. The modtime is only compared at second +// granularity and may be the zero value to mean unknown. +// +// The return value is the effective request "Range" header to use and +// whether this request is now considered done. +func CheckETag(w http.ResponseWriter, r *http.Request, modtime time.Time) (rangeReq string, done bool) { + etag := w.Header().Get("Etag") + rangeReq = r.Header.Get("Range") + + // Invalidate the range request if the entity doesn't match the one + // the client was expecting. + // "If-Range: version" means "ignore the Range: header unless version matches the + // current file." + // We only support ETag versions. + // The caller must have set the ETag on the response already. + if ir := r.Header.Get("If-Range"); ir != "" && ir != etag { + // The If-Range value is typically the ETag value, but it may also be + // the modtime date. See golang.org/issue/8367. + timeMatches := false + if !modtime.IsZero() { + if t, err := http.ParseTime(ir); err == nil && t.Unix() == modtime.Unix() { + timeMatches = true + } + } + if !timeMatches { + rangeReq = "" + } + } + + if inm := r.Header.Get("If-None-Match"); inm != "" { + // Must know ETag. + if etag == "" { + return rangeReq, false + } + + // (bradfitz): non-GET/HEAD requests require more work: + // sending a different status code on matches, and + // also can't use weak cache validators (those with a "W/ + // prefix). But most users of ServeContent will be using + // it on GET or HEAD, so only support those for now. + if r.Method != "GET" && r.Method != "HEAD" { + return rangeReq, false + } + + // (bradfitz): deal with comma-separated or multiple-valued + // list of If-None-match values. For now just handle the common + // case of a single item. + if inm == etag || inm == "*" { + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + w.WriteHeader(http.StatusNotModified) + return "", true + } + } + return rangeReq, false +} + +// ParseRange parses a Range header string as per RFC 2616. +func ParseRange(s string, size int64) ([]HttpRange, error) { + if s == "" { + return nil, nil // header not present + } + const b = "bytes=" + if !strings.HasPrefix(s, b) { + return nil, errors.New("invalid range") + } + var ranges []HttpRange + for _, ra := range strings.Split(s[len(b):], ",") { + ra = strings.TrimSpace(ra) + if ra == "" { + continue + } + i := strings.Index(ra, "-") + if i < 0 { + return nil, errors.New("invalid range") + } + start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:]) + var r HttpRange + if start == "" { + // If no start is specified, end specifies the + // range start relative to the end of the file. + i, err := strconv.ParseInt(end, 10, 64) + if err != nil { + return nil, errors.New("invalid range") + } + if i > size { + i = size + } + r.start = size - i + r.length = size - r.start + } else { + i, err := strconv.ParseInt(start, 10, 64) + if err != nil || i >= size || i < 0 { + return nil, errors.New("invalid range") + } + r.start = i + if end == "" { + // If no end is specified, range extends to end of the file. + r.length = size - r.start + } else { + i, err := strconv.ParseInt(end, 10, 64) + if err != nil || r.start > i { + return nil, errors.New("invalid range") + } + if i >= size { + i = size - 1 + } + r.length = i - r.start + 1 + } + } + ranges = append(ranges, r) + } + return ranges, nil +} + +// RangesMIMESize returns the number of bytes it takes to encode the +// provided ranges as a multipart response. +func RangesMIMESize(ranges []HttpRange, contentType string, contentSize int64) (encSize int64) { + var w CountingWriter + mw := multipart.NewWriter(&w) + for _, ra := range ranges { + _, e := mw.CreatePart(ra.mimeHeader(contentType, contentSize)) + + PanicError(e) + + encSize += ra.length + } + e := mw.Close() + PanicError(e) + encSize += int64(w) + return +} + +func SumRangesSize(ranges []HttpRange) (size int64) { + for _, ra := range ranges { + size += ra.length + } + return +} + + +func PanicError(err error) { + if err != nil { + panic(err) + } +} + + + +//文件下载。具有进度功能。 +//下载功能参考:https://github.com/Masterminds/go-fileserver +func DownloadFile( + writer http.ResponseWriter, + request *http.Request, + filePath string, + filename string, + withContentDisposition bool) { + + diskFile, err := os.Open(filePath) + PanicError(err) + + defer func() { + e := diskFile.Close() + PanicError(e) + }() + + + //根据参数添加content-disposition。该Header会让浏览器自动下载,而不是预览。 + if withContentDisposition { + fileName := url.QueryEscape(filename) + writer.Header().Set("content-disposition", "attachment; filename=\""+fileName+"\"") + } + + //显示文件大小。 + fileInfo, err := diskFile.Stat() + if err != nil { + panic("无法从磁盘中获取文件信息") + } + + modifyTime := fileInfo.ModTime() + + if CheckLastModified(writer, request, modifyTime) { + return + } + rangeReq, done := CheckETag(writer, request, modifyTime) + if done { + return + } + + code := http.StatusOK + + // From net/http/sniff.go + // The algorithm uses at most sniffLen bytes to make its decision. + const sniffLen = 512 + + // If Content-Type isn't set, use the file's extension to find it, but + // if the Content-Type is unset explicitly, do not sniff the type. + ctypes, haveType := writer.Header()["Content-Type"] + var ctype string + if !haveType { + //使用mimeUtil来获取mime + ctype = util.GetFallbackMimeType(filename, "") + if ctype == "" { + // read a chunk to decide between utf-8 text and binary + var buf [sniffLen]byte + n, _ := io.ReadFull(diskFile, buf[:]) + ctype = http.DetectContentType(buf[:n]) + _, err := diskFile.Seek(0, os.SEEK_SET) // rewind to output whole file + if err != nil { + panic("无法准确定位文件") + } + } + writer.Header().Set("Content-Type", ctype) + } else if len(ctypes) > 0 { + ctype = ctypes[0] + } + + size := fileInfo.Size() + + // handle Content-Range header. + sendSize := size + var sendContent io.Reader = diskFile + if size >= 0 { + ranges, err := ParseRange(rangeReq, size) + if err != nil { + panic(result.CustomWebResult(result.CODE_WRAPPER_RANGE_NOT_SATISFIABLE, "range header出错")) + } + if SumRangesSize(ranges) > size { + // The total number of bytes in all the ranges + // is larger than the size of the file by + // itself, so this is probably an attack, or a + // dumb client. Ignore the range request. + ranges = nil + } + switch { + case len(ranges) == 1: + // RFC 2616, Section 14.16: + // "When an HTTP message includes the content of a single + // range (for example, a response to a request for a + // single range, or to a request for a set of ranges + // that overlap without any holes), this content is + // transmitted with a Content-Range header, and a + // Content-Length header showing the number of bytes + // actually transferred. + // ... + // A response to a request for a single range MUST NOT + // be sent using the multipart/byteranges media type." + ra := ranges[0] + if _, err := diskFile.Seek(ra.start, io.SeekStart); err != nil { + panic(result.CustomWebResult(result.CODE_WRAPPER_RANGE_NOT_SATISFIABLE, "range header出错")) + } + sendSize = ra.length + code = http.StatusPartialContent + writer.Header().Set("Content-Range", ra.contentRange(size)) + case len(ranges) > 1: + sendSize = RangesMIMESize(ranges, ctype, size) + code = http.StatusPartialContent + + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + writer.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary()) + sendContent = pr + defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish. + go func() { + for _, ra := range ranges { + part, err := mw.CreatePart(ra.mimeHeader(ctype, size)) + if err != nil { + pw.CloseWithError(err) + return + } + if _, err := diskFile.Seek(ra.start, io.SeekStart); err != nil { + pw.CloseWithError(err) + return + } + if _, err := io.CopyN(part, diskFile, ra.length); err != nil { + pw.CloseWithError(err) + return + } + } + mw.Close() + pw.Close() + }() + } + + writer.Header().Set("Accept-Ranges", "bytes") + if writer.Header().Get("Content-Encoding") == "" { + writer.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10)) + } + } + + writer.WriteHeader(code) + + if request.Method != "HEAD" { + io.CopyN(writer, sendContent, sendSize) + } + +} diff --git a/rest/footprint_controller.go b/rest/footprint_controller.go index d23cf78..c221a5b 100644 --- a/rest/footprint_controller.go +++ b/rest/footprint_controller.go @@ -3,6 +3,7 @@ package rest import ( "net/http" "strconv" + "tank/rest/result" ) type FootprintController struct { @@ -42,7 +43,7 @@ func (this *FootprintController) RegisterRoutes() map[string]func(writer http.Re } //查看详情。 -func (this *FootprintController) Detail(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *FootprintController) Detail(writer http.ResponseWriter, request *http.Request) *result.WebResult { uuid := request.FormValue("uuid") if uuid == "" { @@ -64,7 +65,7 @@ func (this *FootprintController) Detail(writer http.ResponseWriter, request *htt } //按照分页的方式查询 -func (this *FootprintController) Page(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *FootprintController) Page(writer http.ResponseWriter, request *http.Request) *result.WebResult { //如果是根目录,那么就传入root. pageStr := request.FormValue("page") @@ -108,7 +109,7 @@ func (this *FootprintController) Page(writer http.ResponseWriter, request *http. } //删除一条记录 -func (this *FootprintController) Delete(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *FootprintController) Delete(writer http.ResponseWriter, request *http.Request) *result.WebResult { uuid := request.FormValue("uuid") if uuid == "" { diff --git a/rest/image_cache_controller.go b/rest/image_cache_controller.go index c8811c4..6915659 100644 --- a/rest/image_cache_controller.go +++ b/rest/image_cache_controller.go @@ -4,6 +4,7 @@ import ( "net/http" "strconv" "strings" + "tank/rest/result" ) type ImageCacheController struct { @@ -44,7 +45,7 @@ func (this *ImageCacheController) RegisterRoutes() map[string]func(writer http.R } //查看某个图片缓存的详情。 -func (this *ImageCacheController) Detail(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *ImageCacheController) Detail(writer http.ResponseWriter, request *http.Request) *result.WebResult { uuid := request.FormValue("uuid") if uuid == "" { @@ -66,7 +67,7 @@ func (this *ImageCacheController) Detail(writer http.ResponseWriter, request *ht } //按照分页的方式获取某个图片缓存夹下图片缓存和子图片缓存夹的列表,通常情况下只有一页。 -func (this *ImageCacheController) Page(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *ImageCacheController) Page(writer http.ResponseWriter, request *http.Request) *result.WebResult { //如果是根目录,那么就传入root. pageStr := request.FormValue("page") pageSizeStr := request.FormValue("pageSize") @@ -122,7 +123,7 @@ func (this *ImageCacheController) Page(writer http.ResponseWriter, request *http } //删除一个图片缓存 -func (this *ImageCacheController) Delete(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *ImageCacheController) Delete(writer http.ResponseWriter, request *http.Request) *result.WebResult { uuid := request.FormValue("uuid") if uuid == "" { @@ -143,7 +144,7 @@ func (this *ImageCacheController) Delete(writer http.ResponseWriter, request *ht } //删除一系列图片缓存。 -func (this *ImageCacheController) DeleteBatch(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *ImageCacheController) DeleteBatch(writer http.ResponseWriter, request *http.Request) *result.WebResult { uuids := request.FormValue("uuids") if uuids == "" { diff --git a/rest/image_cache_service.go b/rest/image_cache_service.go index 018d113..2706bc1 100644 --- a/rest/image_cache_service.go +++ b/rest/image_cache_service.go @@ -9,6 +9,7 @@ import ( "path/filepath" "strconv" "strings" + "tank/rest/util" ) //@Service @@ -179,7 +180,7 @@ func (this *ImageCacheService) ResizeImage(request *http.Request, filePath strin func (this *ImageCacheService) cacheImage(writer http.ResponseWriter, request *http.Request, matter *Matter) *ImageCache { //当前的文件是否是图片,只有图片才能处理。 - extension := GetExtension(matter.Name) + extension := util.GetExtension(matter.Name) formats := map[string]imaging.Format{ ".jpg": imaging.JPEG, ".jpeg": imaging.JPEG, @@ -203,9 +204,9 @@ func (this *ImageCacheService) cacheImage(writer http.ResponseWriter, request *h //resize图片 dstImage := this.ResizeImage(request, matter.AbsolutePath()) - cacheImageName := GetSimpleFileName(matter.Name) + "_" + mode + extension - cacheImageRelativePath := GetSimpleFileName(matter.Path) + "_" + mode + extension - cacheImageAbsolutePath := GetUserCacheRootDir(user.Username) + GetSimpleFileName(matter.Path) + "_" + mode + extension + cacheImageName := util.GetSimpleFileName(matter.Name) + "_" + mode + extension + cacheImageRelativePath := util.GetSimpleFileName(matter.Path) + "_" + mode + extension + cacheImageAbsolutePath := GetUserCacheRootDir(user.Username) + util.GetSimpleFileName(matter.Path) + "_" + mode + extension //创建目录。 dir := filepath.Dir(cacheImageAbsolutePath) diff --git a/rest/install_controller.go b/rest/install_controller.go index a52377d..0b2c012 100644 --- a/rest/install_controller.go +++ b/rest/install_controller.go @@ -11,6 +11,7 @@ import ( "os" "regexp" "strconv" + "tank/rest/result" "time" ) @@ -226,7 +227,7 @@ func (this *InstallController) validateTableMetaList(tableInfoList []*InstallTab } //验证数据库连接 -func (this *InstallController) Verify(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *InstallController) Verify(writer http.ResponseWriter, request *http.Request) *result.WebResult { db := this.openDbConnection(writer, request) defer this.closeDbConnection(db) @@ -239,7 +240,7 @@ func (this *InstallController) Verify(writer http.ResponseWriter, request *http. } //获取需要安装的数据库表 -func (this *InstallController) TableInfoList(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *InstallController) TableInfoList(writer http.ResponseWriter, request *http.Request) *result.WebResult { db := this.openDbConnection(writer, request) defer this.closeDbConnection(db) @@ -248,7 +249,7 @@ func (this *InstallController) TableInfoList(writer http.ResponseWriter, request } //创建缺失数据库和表 -func (this *InstallController) CreateTable(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *InstallController) CreateTable(writer http.ResponseWriter, request *http.Request) *result.WebResult { var tableNames = []IBase{&Dashboard{}, &DownloadToken{}, &Footprint{}, &ImageCache{}, &Matter{}, &Preference{}, &Session{}, &UploadToken{}, &User{}} var installTableInfos []*InstallTableInfo @@ -277,7 +278,7 @@ func (this *InstallController) CreateTable(writer http.ResponseWriter, request * } //获取管理员列表(10条记录) -func (this *InstallController) AdminList(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *InstallController) AdminList(writer http.ResponseWriter, request *http.Request) *result.WebResult { db := this.openDbConnection(writer, request) defer this.closeDbConnection(db) @@ -295,7 +296,7 @@ func (this *InstallController) AdminList(writer http.ResponseWriter, request *ht } //创建管理员 -func (this *InstallController) CreateAdmin(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *InstallController) CreateAdmin(writer http.ResponseWriter, request *http.Request) *result.WebResult { db := this.openDbConnection(writer, request) defer this.closeDbConnection(db) @@ -356,7 +357,7 @@ func (this *InstallController) CreateAdmin(writer http.ResponseWriter, request * } //(如果数据库中本身存在管理员了)验证管理员 -func (this *InstallController) ValidateAdmin(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *InstallController) ValidateAdmin(writer http.ResponseWriter, request *http.Request) *result.WebResult { db := this.openDbConnection(writer, request) defer this.closeDbConnection(db) @@ -391,7 +392,7 @@ func (this *InstallController) ValidateAdmin(writer http.ResponseWriter, request } //完成系统安装 -func (this *InstallController) Finish(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *InstallController) Finish(writer http.ResponseWriter, request *http.Request) *result.WebResult { mysqlPortStr := request.FormValue("mysqlPort") mysqlHost := request.FormValue("mysqlHost") diff --git a/rest/matter_controller.go b/rest/matter_controller.go index c38378f..99397f9 100644 --- a/rest/matter_controller.go +++ b/rest/matter_controller.go @@ -4,6 +4,7 @@ import ( "net/http" "strconv" "strings" + "tank/rest/result" ) type MatterController struct { @@ -67,7 +68,7 @@ func (this *MatterController) RegisterRoutes() map[string]func(writer http.Respo } //查看某个文件的详情。 -func (this *MatterController) Detail(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *MatterController) Detail(writer http.ResponseWriter, request *http.Request) *result.WebResult { uuid := request.FormValue("uuid") if uuid == "" { @@ -89,7 +90,7 @@ func (this *MatterController) Detail(writer http.ResponseWriter, request *http.R } //创建一个文件夹。 -func (this *MatterController) CreateDirectory(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *MatterController) CreateDirectory(writer http.ResponseWriter, request *http.Request) *result.WebResult { puuid := request.FormValue("puuid") @@ -108,7 +109,7 @@ func (this *MatterController) CreateDirectory(writer http.ResponseWriter, reques } //按照分页的方式获取某个文件夹下文件和子文件夹的列表,通常情况下只有一页。 -func (this *MatterController) Page(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *MatterController) Page(writer http.ResponseWriter, request *http.Request) *result.WebResult { //如果是根目录,那么就传入root. pageStr := request.FormValue("page") @@ -189,7 +190,7 @@ func (this *MatterController) Page(writer http.ResponseWriter, request *http.Req } //上传文件 -func (this *MatterController) Upload(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *MatterController) Upload(writer http.ResponseWriter, request *http.Request) *result.WebResult { userUuid := request.FormValue("userUuid") puuid := request.FormValue("puuid") @@ -230,7 +231,7 @@ func (this *MatterController) Upload(writer http.ResponseWriter, request *http.R } //从一个Url中去爬取资源 -func (this *MatterController) Crawl(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *MatterController) Crawl(writer http.ResponseWriter, request *http.Request) *result.WebResult { userUuid := request.FormValue("userUuid") user := this.checkUser(writer, request) @@ -280,7 +281,7 @@ func (this *MatterController) Crawl(writer http.ResponseWriter, request *http.Re } //删除一个文件 -func (this *MatterController) Delete(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *MatterController) Delete(writer http.ResponseWriter, request *http.Request) *result.WebResult { uuid := request.FormValue("uuid") if uuid == "" { @@ -301,7 +302,7 @@ func (this *MatterController) Delete(writer http.ResponseWriter, request *http.R } //删除一系列文件。 -func (this *MatterController) DeleteBatch(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *MatterController) DeleteBatch(writer http.ResponseWriter, request *http.Request) *result.WebResult { uuids := request.FormValue("uuids") if uuids == "" { @@ -335,7 +336,7 @@ func (this *MatterController) DeleteBatch(writer http.ResponseWriter, request *h //重命名一个文件或一个文件夹 -func (this *MatterController) Rename(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *MatterController) Rename(writer http.ResponseWriter, request *http.Request) *result.WebResult { uuid := request.FormValue("uuid") name := request.FormValue("name") @@ -356,7 +357,7 @@ func (this *MatterController) Rename(writer http.ResponseWriter, request *http.R } //改变一个文件的公私有属性 -func (this *MatterController) ChangePrivacy(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *MatterController) ChangePrivacy(writer http.ResponseWriter, request *http.Request) *result.WebResult { uuid := request.FormValue("uuid") privacyStr := request.FormValue("privacy") privacy := false @@ -383,7 +384,7 @@ func (this *MatterController) ChangePrivacy(writer http.ResponseWriter, request } //将一个文件夹或者文件移入到另一个文件夹下。 -func (this *MatterController) Move(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *MatterController) Move(writer http.ResponseWriter, request *http.Request) *result.WebResult { srcUuidsStr := request.FormValue("srcUuids") destUuid := request.FormValue("destUuid") diff --git a/rest/matter_dao.go b/rest/matter_dao.go index 14412c9..c881cb5 100644 --- a/rest/matter_dao.go +++ b/rest/matter_dao.go @@ -4,6 +4,7 @@ import ( "github.com/jinzhu/gorm" "github.com/nu7hatch/gouuid" "os" + "tank/rest/result" "time" ) @@ -292,7 +293,7 @@ func (this *MatterDao) findByUserUuidAndPath(userUuid string, path string) *Matt db := CONTEXT.DB.Model(&Matter{}).Where(wp.Query, wp.Args...).First(matter) if db.Error != nil { - if db.Error.Error() == DB_ERROR_NOT_FOUND { + if db.Error.Error() == result.DB_ERROR_NOT_FOUND { return nil } else { this.PanicError(db.Error) diff --git a/rest/matter_model.go b/rest/matter_model.go index bc858bf..2c777cc 100644 --- a/rest/matter_model.go +++ b/rest/matter_model.go @@ -1,5 +1,7 @@ package rest +import "tank/rest/util" + const ( MATTER_ROOT = "root" MATTER_CACHE = "cache" @@ -36,7 +38,7 @@ func (this *Matter) AbsolutePath() string { // 获取该Matter的MimeType func (this *Matter) MimeType() string { - return GetMimeType(GetExtension(this.Name)) + return util.GetMimeType(util.GetExtension(this.Name)) } diff --git a/rest/matter_service.go b/rest/matter_service.go index 6f73cc2..df6854a 100644 --- a/rest/matter_service.go +++ b/rest/matter_service.go @@ -1,18 +1,13 @@ package rest import ( - "errors" "fmt" "io" - "mime/multipart" "net/http" - "net/textproto" - "net/url" "os" "regexp" - "strconv" "strings" - "time" + "tank/rest/download" ) //@Service @@ -20,6 +15,7 @@ type MatterService struct { Bean matterDao *MatterDao userDao *UserDao + userService *UserService imageCacheDao *ImageCacheDao imageCacheService *ImageCacheService } @@ -39,6 +35,11 @@ func (this *MatterService) Init() { this.userDao = b } + b = CONTEXT.GetBean(this.userService) + if b, ok := b.(*UserService); ok { + this.userService = b + } + b = CONTEXT.GetBean(this.imageCacheDao) if b, ok := b.(*ImageCacheDao); ok { this.imageCacheDao = b @@ -372,193 +373,6 @@ func (this *MatterService) Crawl(url string, filename string, user *User, puuid return matter } -// httpRange specifies the byte range to be sent to the client. -type httpRange struct { - start, length int64 -} - -func (r httpRange) contentRange(size int64) string { - return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size) -} - -func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader { - return textproto.MIMEHeader{ - "Content-Range": {r.contentRange(size)}, - "Content-Type": {contentType}, - } -} - -// countingWriter counts how many bytes have been written to it. -type countingWriter int64 - -func (w *countingWriter) Write(p []byte) (n int, err error) { - *w += countingWriter(len(p)) - return len(p), nil -} - -//检查Last-Modified头。返回true: 请求已经完成了。(言下之意,文件没有修改过) 返回false:文件修改过。 -func (this *MatterService) checkLastModified(w http.ResponseWriter, r *http.Request, modifyTime time.Time) bool { - if modifyTime.IsZero() { - return false - } - - // The Date-Modified header truncates sub-second precision, so - // use mtime < t+1s instead of mtime <= t to check for unmodified. - if t, err := time.Parse(http.TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modifyTime.Before(t.Add(1*time.Second)) { - h := w.Header() - delete(h, "Content-Type") - delete(h, "Content-Length") - w.WriteHeader(http.StatusNotModified) - return true - } - w.Header().Set("Last-Modified", modifyTime.UTC().Format(http.TimeFormat)) - return false -} - -// 处理ETag标签 -// checkETag implements If-None-Match and If-Range checks. -// -// The ETag or modtime must have been previously set in the -// ResponseWriter's headers. The modtime is only compared at second -// granularity and may be the zero value to mean unknown. -// -// The return value is the effective request "Range" header to use and -// whether this request is now considered done. -func (this *MatterService) checkETag(w http.ResponseWriter, r *http.Request, modtime time.Time) (rangeReq string, done bool) { - etag := w.Header().Get("Etag") - rangeReq = r.Header.Get("Range") - - // Invalidate the range request if the entity doesn't match the one - // the client was expecting. - // "If-Range: version" means "ignore the Range: header unless version matches the - // current file." - // We only support ETag versions. - // The caller must have set the ETag on the response already. - if ir := r.Header.Get("If-Range"); ir != "" && ir != etag { - // The If-Range value is typically the ETag value, but it may also be - // the modtime date. See golang.org/issue/8367. - timeMatches := false - if !modtime.IsZero() { - if t, err := http.ParseTime(ir); err == nil && t.Unix() == modtime.Unix() { - timeMatches = true - } - } - if !timeMatches { - rangeReq = "" - } - } - - if inm := r.Header.Get("If-None-Match"); inm != "" { - // Must know ETag. - if etag == "" { - return rangeReq, false - } - - // (bradfitz): non-GET/HEAD requests require more work: - // sending a different status code on matches, and - // also can't use weak cache validators (those with a "W/ - // prefix). But most users of ServeContent will be using - // it on GET or HEAD, so only support those for now. - if r.Method != "GET" && r.Method != "HEAD" { - return rangeReq, false - } - - // (bradfitz): deal with comma-separated or multiple-valued - // list of If-None-match values. For now just handle the common - // case of a single item. - if inm == etag || inm == "*" { - h := w.Header() - delete(h, "Content-Type") - delete(h, "Content-Length") - w.WriteHeader(http.StatusNotModified) - return "", true - } - } - return rangeReq, false -} - -// parseRange parses a Range header string as per RFC 2616. -func (this *MatterService) parseRange(s string, size int64) ([]httpRange, error) { - if s == "" { - return nil, nil // header not present - } - const b = "bytes=" - if !strings.HasPrefix(s, b) { - return nil, errors.New("invalid range") - } - var ranges []httpRange - for _, ra := range strings.Split(s[len(b):], ",") { - ra = strings.TrimSpace(ra) - if ra == "" { - continue - } - i := strings.Index(ra, "-") - if i < 0 { - return nil, errors.New("invalid range") - } - start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:]) - var r httpRange - if start == "" { - // If no start is specified, end specifies the - // range start relative to the end of the file. - i, err := strconv.ParseInt(end, 10, 64) - if err != nil { - return nil, errors.New("invalid range") - } - if i > size { - i = size - } - r.start = size - i - r.length = size - r.start - } else { - i, err := strconv.ParseInt(start, 10, 64) - if err != nil || i >= size || i < 0 { - return nil, errors.New("invalid range") - } - r.start = i - if end == "" { - // If no end is specified, range extends to end of the file. - r.length = size - r.start - } else { - i, err := strconv.ParseInt(end, 10, 64) - if err != nil || r.start > i { - return nil, errors.New("invalid range") - } - if i >= size { - i = size - 1 - } - r.length = i - r.start + 1 - } - } - ranges = append(ranges, r) - } - return ranges, nil -} - -// rangesMIMESize returns the number of bytes it takes to encode the -// provided ranges as a multipart response. -func (this *MatterService) rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) { - var w countingWriter - mw := multipart.NewWriter(&w) - for _, ra := range ranges { - _, e := mw.CreatePart(ra.mimeHeader(contentType, contentSize)) - this.PanicError(e) - - encSize += ra.length - } - e := mw.Close() - this.PanicError(e) - encSize += int64(w) - return -} - -func (this *MatterService) sumRangesSize(ranges []httpRange) (size int64) { - for _, ra := range ranges { - size += ra.length - } - return -} - //文件下载。具有进度功能。 //下载功能参考:https://github.com/Masterminds/go-fileserver func (this *MatterService) DownloadFile( @@ -568,143 +382,7 @@ func (this *MatterService) DownloadFile( filename string, withContentDisposition bool) { - diskFile, err := os.Open(filePath) - this.PanicError(err) - defer func() { - e := diskFile.Close() - this.PanicError(e) - }() - - //根据参数添加content-disposition。该Header会让浏览器自动下载,而不是预览。 - if withContentDisposition { - fileName := url.QueryEscape(filename) - writer.Header().Set("content-disposition", "attachment; filename=\""+fileName+"\"") - } - - //显示文件大小。 - fileInfo, err := diskFile.Stat() - if err != nil { - this.PanicServer("无法从磁盘中获取文件信息") - } - - modifyTime := fileInfo.ModTime() - - if this.checkLastModified(writer, request, modifyTime) { - return - } - rangeReq, done := this.checkETag(writer, request, modifyTime) - if done { - return - } - - code := http.StatusOK - - // From net/http/sniff.go - // The algorithm uses at most sniffLen bytes to make its decision. - const sniffLen = 512 - - // If Content-Type isn't set, use the file's extension to find it, but - // if the Content-Type is unset explicitly, do not sniff the type. - ctypes, haveType := writer.Header()["Content-Type"] - var ctype string - if !haveType { - //放弃原有的判断mime的方法 - //ctype = mime.TypeByExtension(filepath.Ext(fileInfo.Name())) - //使用mimeUtil来获取mime - ctype = GetFallbackMimeType(filename, "") - if ctype == "" { - // read a chunk to decide between utf-8 text and binary - var buf [sniffLen]byte - n, _ := io.ReadFull(diskFile, buf[:]) - ctype = http.DetectContentType(buf[:n]) - _, err := diskFile.Seek(0, os.SEEK_SET) // rewind to output whole file - if err != nil { - this.PanicServer("无法准确定位文件") - } - } - writer.Header().Set("Content-Type", ctype) - } else if len(ctypes) > 0 { - ctype = ctypes[0] - } - - size := fileInfo.Size() - - // handle Content-Range header. - sendSize := size - var sendContent io.Reader = diskFile - if size >= 0 { - ranges, err := this.parseRange(rangeReq, size) - if err != nil { - panic(CustomWebResult(CODE_WRAPPER_RANGE_NOT_SATISFIABLE, "range header出错")) - } - if this.sumRangesSize(ranges) > size { - // The total number of bytes in all the ranges - // is larger than the size of the file by - // itself, so this is probably an attack, or a - // dumb client. Ignore the range request. - ranges = nil - } - switch { - case len(ranges) == 1: - // RFC 2616, Section 14.16: - // "When an HTTP message includes the content of a single - // range (for example, a response to a request for a - // single range, or to a request for a set of ranges - // that overlap without any holes), this content is - // transmitted with a Content-Range header, and a - // Content-Length header showing the number of bytes - // actually transferred. - // ... - // A response to a request for a single range MUST NOT - // be sent using the multipart/byteranges media type." - ra := ranges[0] - if _, err := diskFile.Seek(ra.start, io.SeekStart); err != nil { - panic(CustomWebResult(CODE_WRAPPER_RANGE_NOT_SATISFIABLE, "range header出错")) - } - sendSize = ra.length - code = http.StatusPartialContent - writer.Header().Set("Content-Range", ra.contentRange(size)) - case len(ranges) > 1: - sendSize = this.rangesMIMESize(ranges, ctype, size) - code = http.StatusPartialContent - - pr, pw := io.Pipe() - mw := multipart.NewWriter(pw) - writer.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary()) - sendContent = pr - defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish. - go func() { - for _, ra := range ranges { - part, err := mw.CreatePart(ra.mimeHeader(ctype, size)) - if err != nil { - pw.CloseWithError(err) - return - } - if _, err := diskFile.Seek(ra.start, io.SeekStart); err != nil { - pw.CloseWithError(err) - return - } - if _, err := io.CopyN(part, diskFile, ra.length); err != nil { - pw.CloseWithError(err) - return - } - } - mw.Close() - pw.Close() - }() - } - - writer.Header().Set("Accept-Ranges", "bytes") - if writer.Header().Get("Content-Encoding") == "" { - writer.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10)) - } - } - - writer.WriteHeader(code) - - if request.Method != "HEAD" { - io.CopyN(writer, sendContent, sendSize) - } + download.DownloadFile(writer, request, filePath, filename, withContentDisposition) } diff --git a/rest/preference_controller.go b/rest/preference_controller.go index 22f10b3..2944c8b 100644 --- a/rest/preference_controller.go +++ b/rest/preference_controller.go @@ -2,6 +2,7 @@ package rest import ( "net/http" + "tank/rest/result" ) type PreferenceController struct { @@ -41,7 +42,7 @@ func (this *PreferenceController) RegisterRoutes() map[string]func(writer http.R } //查看某个偏好设置的详情。 -func (this *PreferenceController) Fetch(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *PreferenceController) Fetch(writer http.ResponseWriter, request *http.Request) *result.WebResult { preference := this.preferenceService.Fetch() @@ -49,7 +50,7 @@ func (this *PreferenceController) Fetch(writer http.ResponseWriter, request *htt } //修改 -func (this *PreferenceController) Edit(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *PreferenceController) Edit(writer http.ResponseWriter, request *http.Request) *result.WebResult { //验证参数。 name := request.FormValue("name") @@ -84,7 +85,7 @@ func (this *PreferenceController) Edit(writer http.ResponseWriter, request *http } //清扫系统,所有数据全部丢失。一定要非常慎点,非常慎点!只在系统初始化的时候点击! -func (this *PreferenceController) SystemCleanup(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *PreferenceController) SystemCleanup(writer http.ResponseWriter, request *http.Request) *result.WebResult { user := this.checkUser(writer, request) password := request.FormValue("password") diff --git a/rest/preference_dao.go b/rest/preference_dao.go index a992f96..3d61e75 100644 --- a/rest/preference_dao.go +++ b/rest/preference_dao.go @@ -2,6 +2,7 @@ package rest import ( "github.com/nu7hatch/gouuid" + "tank/rest/result" "time" ) @@ -17,7 +18,7 @@ func (this *PreferenceDao) Fetch() *Preference { db := CONTEXT.DB.First(preference) if db.Error != nil { - if db.Error.Error() == DB_ERROR_NOT_FOUND { + if db.Error.Error() == result.DB_ERROR_NOT_FOUND { preference.Name = "蓝眼云盘" preference.ShowAlien = true this.Create(preference) diff --git a/rest/web_result.go b/rest/result/web_result.go similarity index 99% rename from rest/web_result.go rename to rest/result/web_result.go index c0b1e73..494788e 100644 --- a/rest/web_result.go +++ b/rest/result/web_result.go @@ -1,4 +1,4 @@ -package rest +package result import "net/http" diff --git a/rest/router.go b/rest/router.go index c852627..e1b752f 100644 --- a/rest/router.go +++ b/rest/router.go @@ -7,6 +7,8 @@ import ( "net/http" "os" "strings" + "tank/rest/result" + "tank/rest/util" "time" ) @@ -70,26 +72,26 @@ func (this *Router) GlobalPanicHandler(writer http.ResponseWriter, request *http LOGGER.Error("错误: %v", err) - var webResult *WebResult = nil + var webResult *result.WebResult = nil if value, ok := err.(string); ok { //一个字符串,默认是请求错误。 - webResult = CustomWebResult(CODE_WRAPPER_BAD_REQUEST, value) - } else if value, ok := err.(*WebResult); ok { + webResult = result.CustomWebResult(result.CODE_WRAPPER_BAD_REQUEST, value) + } else if value, ok := err.(*result.WebResult); ok { //一个WebResult对象 webResult = value - } else if value, ok := err.(*CodeWrapper); ok { + } else if value, ok := err.(*result.CodeWrapper); ok { //一个WebResult对象 - webResult = ConstWebResult(value) + webResult = result.ConstWebResult(value) } else if value, ok := err.(error); ok { //一个普通的错误对象 - webResult = CustomWebResult(CODE_WRAPPER_UNKNOWN, value.Error()) + webResult = result.CustomWebResult(result.CODE_WRAPPER_UNKNOWN, value.Error()) } else { //其他不能识别的内容 - webResult = ConstWebResult(CODE_WRAPPER_UNKNOWN) + webResult = result.ConstWebResult(result.CODE_WRAPPER_UNKNOWN) } //修改http code码 - writer.WriteHeader(FetchHttpStatus(webResult.Code)) + writer.WriteHeader(result.FetchHttpStatus(webResult.Code)) //输出的是json格式 返回的内容申明是json,utf-8 writer.Header().Set("Content-Type", "application/json;charset=UTF-8") @@ -148,7 +150,7 @@ func (this *Router) ServeHTTP(writer http.ResponseWriter, request *http.Request) } if !canHandle { - panic(CustomWebResult(CODE_WRAPPER_NOT_FOUND, fmt.Sprintf("没有找到能够处理%s的方法", path))) + panic(result.CustomWebResult(result.CODE_WRAPPER_NOT_FOUND, fmt.Sprintf("没有找到能够处理%s的方法", path))) } } @@ -162,7 +164,7 @@ func (this *Router) ServeHTTP(writer http.ResponseWriter, request *http.Request) if handler, ok := this.installRouteMap[path]; ok { handler(writer, request) } else { - panic(ConstWebResult(CODE_WRAPPER_NOT_INSTALLED)) + panic(result.ConstWebResult(result.CODE_WRAPPER_NOT_INSTALLED)) } } @@ -185,7 +187,7 @@ func (this *Router) ServeHTTP(writer http.ResponseWriter, request *http.Request) } } - writer.Header().Set("Content-Type", GetMimeType(GetExtension(filePath))) + writer.Header().Set("Content-Type", util.GetMimeType(util.GetExtension(filePath))) diskFile, err := os.Open(filePath) if err != nil { diff --git a/rest/user_controller.go b/rest/user_controller.go index 2dcccb6..7477d7f 100644 --- a/rest/user_controller.go +++ b/rest/user_controller.go @@ -4,6 +4,7 @@ import ( "net/http" "regexp" "strconv" + "tank/rest/result" "time" ) @@ -41,7 +42,7 @@ func (this *UserController) RegisterRoutes() map[string]func(writer http.Respons //参数: // @email:邮箱 // @password:密码 -func (this *UserController) Login(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *UserController) Login(writer http.ResponseWriter, request *http.Request) *result.WebResult { email := request.FormValue("email") password := request.FormValue("password") @@ -55,6 +56,7 @@ func (this *UserController) Login(writer http.ResponseWriter, request *http.Requ if user == nil { this.PanicBadRequest("邮箱或密码错误") + } else { if !MatchBcrypt(password, user.Password) { @@ -93,7 +95,7 @@ func (this *UserController) Login(writer http.ResponseWriter, request *http.Requ } //创建一个用户 -func (this *UserController) Create(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *UserController) Create(writer http.ResponseWriter, request *http.Request) *result.WebResult { username := request.FormValue("username") if m, _ := regexp.MatchString(`^[0-9a-zA-Z_]+$`, username); !m { @@ -156,7 +158,7 @@ func (this *UserController) Create(writer http.ResponseWriter, request *http.Req } //编辑一个用户的资料。 -func (this *UserController) Edit(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *UserController) Edit(writer http.ResponseWriter, request *http.Request) *result.WebResult { avatarUrl := request.FormValue("avatarUrl") uuid := request.FormValue("uuid") @@ -199,7 +201,7 @@ func (this *UserController) Edit(writer http.ResponseWriter, request *http.Reque } //获取用户详情 -func (this *UserController) Detail(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *UserController) Detail(writer http.ResponseWriter, request *http.Request) *result.WebResult { uuid := request.FormValue("uuid") @@ -210,7 +212,7 @@ func (this *UserController) Detail(writer http.ResponseWriter, request *http.Req } //退出登录 -func (this *UserController) Logout(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *UserController) Logout(writer http.ResponseWriter, request *http.Request) *result.WebResult { //session置为过期 sessionCookie, err := request.Cookie(COOKIE_AUTH_KEY) @@ -246,7 +248,7 @@ func (this *UserController) Logout(writer http.ResponseWriter, request *http.Req } //获取用户列表 管理员的权限。 -func (this *UserController) Page(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *UserController) Page(writer http.ResponseWriter, request *http.Request) *result.WebResult { pageStr := request.FormValue("page") pageSizeStr := request.FormValue("pageSize") @@ -298,7 +300,7 @@ func (this *UserController) Page(writer http.ResponseWriter, request *http.Reque } //禁用用户 -func (this *UserController) Disable(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *UserController) Disable(writer http.ResponseWriter, request *http.Request) *result.WebResult { uuid := request.FormValue("uuid") @@ -322,7 +324,7 @@ func (this *UserController) Disable(writer http.ResponseWriter, request *http.Re } //启用用户 -func (this *UserController) Enable(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *UserController) Enable(writer http.ResponseWriter, request *http.Request) *result.WebResult { uuid := request.FormValue("uuid") @@ -345,7 +347,7 @@ func (this *UserController) Enable(writer http.ResponseWriter, request *http.Req } //用户修改密码 -func (this *UserController) ChangePassword(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *UserController) ChangePassword(writer http.ResponseWriter, request *http.Request) *result.WebResult { oldPassword := request.FormValue("oldPassword") newPassword := request.FormValue("newPassword") @@ -372,7 +374,7 @@ func (this *UserController) ChangePassword(writer http.ResponseWriter, request * } //管理员重置用户密码 -func (this *UserController) ResetPassword(writer http.ResponseWriter, request *http.Request) *WebResult { +func (this *UserController) ResetPassword(writer http.ResponseWriter, request *http.Request) *result.WebResult { userUuid := request.FormValue("userUuid") password := request.FormValue("password") diff --git a/rest/user_service.go b/rest/user_service.go index 008f7fe..43e2cb9 100644 --- a/rest/user_service.go +++ b/rest/user_service.go @@ -10,6 +10,9 @@ type UserService struct { Bean userDao *UserDao sessionDao *SessionDao + + //操作文件的锁。 + locker *CacheTable } //初始化方法 @@ -27,8 +30,45 @@ func (this *UserService) Init() { this.sessionDao = b } + //创建一个用于存储用户文件锁的缓存。 + this.locker = NewCacheTable() } + +//对某个用户进行加锁。加锁阶段用户是不允许操作文件的。 +func (this *UserService) MatterLock(userUuid string) { + //如果已经是锁住的状态,直接报错 + + //去缓存中捞取 + cacheItem, err := this.locker.Value(userUuid) + if err != nil { + this.logger.Error("获取缓存时出错了" + err.Error()) + } + + //当前被锁住了。 + if cacheItem != nil && cacheItem.Data() != nil { + this.PanicBadRequest("当前正在进行文件操作,请稍后再试!") + } + + //添加一把新锁,有效期为12小时 + duration := 12 * time.Hour + this.locker.Add(userUuid, duration, true) +} + + +//对某个用户解锁,解锁后用户可以操作文件。 +func (this *UserService) MatterUnlock(userUuid string) { + + exist := this.locker.Exists(userUuid) + if exist { + _, err := this.locker.Delete(userUuid) + this.PanicError(err) + } else { + this.logger.Error("%s已经不存在matter锁了,解锁错误。", userUuid) + } +} + + //装载session信息,如果session没有了根据cookie去装填用户信息。 //在所有的路由最初会调用这个方法 func (this *UserService) bootstrap(writer http.ResponseWriter, request *http.Request) { diff --git a/rest/util_mime.go b/rest/util/util_mime.go similarity index 99% rename from rest/util_mime.go rename to rest/util/util_mime.go index 8091eb6..2f6ea7e 100644 --- a/rest/util_mime.go +++ b/rest/util/util_mime.go @@ -1,4 +1,4 @@ -package rest +package util import ( "os"