diff --git a/rest/base_controller.go b/rest/base_controller.go index e091cc5..75314cc 100644 --- a/rest/base_controller.go +++ b/rest/base_controller.go @@ -5,7 +5,6 @@ import ( "github.com/json-iterator/go" "go/types" "net/http" - "time" ) type IController interface { @@ -131,64 +130,47 @@ func (this *BaseController) Error(err interface{}) *WebResult { return webResult } -//能找到一个user就找到一个,遇到问题直接抛出错误 -func (this *BaseController) checkLogin(writer http.ResponseWriter, request *http.Request) (*Session, *User) { - - //验证用户是否已经登录。 - sessionCookie, err := request.Cookie(COOKIE_AUTH_KEY) - if err != nil { - panic(ConstWebResult(CODE_WRAPPER_LOGIN)) - } - - session := this.sessionDao.FindByUuid(sessionCookie.Value) - if session == nil { - panic(ConstWebResult(CODE_WRAPPER_LOGIN)) - } else { - if session.ExpireTime.Before(time.Now()) { - panic(ConstWebResult(CODE_WRAPPER_LOGIN_EXPIRE)) - } else { - - user := this.userDao.FindByUuid(session.UserUuid) - if user == nil { - panic(ConstWebResult(CODE_WRAPPER_LOGIN)) - } else { - return session, user - } - - } - } - -} - //能找到一个user就找到一个 func (this *BaseController) findUser(writer http.ResponseWriter, request *http.Request) *User { //验证用户是否已经登录。 sessionCookie, err := request.Cookie(COOKIE_AUTH_KEY) if err != nil { - LogError("找不到任何登录信息") + LogInfo("获取用户cookie时出错啦") return nil } - session := this.sessionDao.FindByUuid(sessionCookie.Value) - if session != nil { - if session.ExpireTime.Before(time.Now()) { - LogError("登录信息已过期") - return nil - } else { - user := this.userDao.FindByUuid(session.UserUuid) - if user != nil { - return user - } - } + sessionId := sessionCookie.Value + + LogInfo("findUser sessionId = " + sessionId) + + //去缓存中捞取看看 + cacheItem, err := CONTEXT.SessionCache.Value(sessionId) + if err != nil { + LogError("获取缓存时出错了" + err.Error()) + return nil + } + + if cacheItem.Data() == nil { + LogError("cache item中已经不存在了 " + err.Error()) + return nil + } + + if value, ok := cacheItem.Data().(*User); ok { + return value + } else { + LogError("cache item中的类型不是*User ") } return nil } func (this *BaseController) checkUser(writer http.ResponseWriter, request *http.Request) *User { - _, user := this.checkLogin(writer, request) - return user + if this.findUser(writer, request) == nil { + panic(ConstWebResult(CODE_WRAPPER_LOGIN)) + } else { + return this.findUser(writer, request) + } } //允许跨域请求 diff --git a/rest/context.go b/rest/context.go index f8d2419..229d100 100644 --- a/rest/context.go +++ b/rest/context.go @@ -25,6 +25,31 @@ type Context struct { Router *Router } + +//初始化上下文 +func (this *Context) Init() { + + //处理数据库连接的开关。 + this.OpenDb() + + //创建一个用于存储session的缓存。 + this.SessionCache = NewCacheTable() + + //初始化Map + this.BeanMap = make(map[string]IBean) + this.ControllerMap = make(map[string]IController) + + //注册各类Beans.在这个方法里面顺便把Controller装入ControllerMap中去。 + this.registerBeans() + + //初始化每个bean. + this.initBeans() + + //初始化Router. 这个方法要在Bean注册好了之后才能。 + this.Router = NewRouter() +} + + func (this *Context) OpenDb() { var err error = nil @@ -47,28 +72,6 @@ func (this *Context) CloseDb() { } } -//构造方法 -func (this *Context) Init() { - - //处理数据库连接的开关。 - this.OpenDb() - - //创建一个用于存储session的缓存。 - this.SessionCache = NewCacheTable() - - //初始化Map - this.BeanMap = make(map[string]IBean) - this.ControllerMap = make(map[string]IController) - - //注册各类Beans.在这个方法里面顺便把Controller装入ControllerMap中去。 - this.registerBeans() - - //初始化每个bean. - this.initBeans() - - //初始化Router. 这个方法要在Bean注册好了之后才能。 - this.Router = NewRouter() -} //注册一个Bean func (this *Context) registerBean(bean IBean) { diff --git a/rest/router.go b/rest/router.go index 9f7baeb..24b3f52 100644 --- a/rest/router.go +++ b/rest/router.go @@ -1,7 +1,6 @@ package rest import ( - "encoding/json" "fmt" "github.com/json-iterator/go" "io" @@ -12,8 +11,9 @@ import ( //用于处理所有前来的请求 type Router struct { - userService *UserService - routeMap map[string]func(writer http.ResponseWriter, request *http.Request) + securityVisitService *SecurityVisitService + userService *UserService + routeMap map[string]func(writer http.ResponseWriter, request *http.Request) } //构造方法 @@ -22,13 +22,18 @@ func NewRouter() *Router { routeMap: make(map[string]func(writer http.ResponseWriter, request *http.Request)), } - //装载userService. b := CONTEXT.GetBean(router.userService) if b, ok := b.(*UserService); ok { router.userService = b } + //装载securityVisitService + b = CONTEXT.GetBean(router.securityVisitService) + if b, ok := b.(*SecurityVisitService); ok { + router.securityVisitService = b + } + //将Controller中的路由规则装载机进来 for _, controller := range CONTEXT.ControllerMap { routes := controller.RegisterRoutes() @@ -80,62 +85,13 @@ func (this *Router) GlobalPanicHandler(writer http.ResponseWriter, request *http var json = jsoniter.ConfigCompatibleWithStandardLibrary b, _ := json.Marshal(webResult) - fmt.Fprintf(writer, string(b)) + _, err := fmt.Fprintf(writer, string(b)) + if err != nil { + fmt.Printf("输出结果时出错了\n") + } } } -//记录访问记录 -func (this *Router) logSecurityVisit(writer http.ResponseWriter, request *http.Request) { - //手动装填本实例的Bean. 这里必须要用中间变量方可。 - var securityVisitDao *SecurityVisitDao - b := CONTEXT.GetBean(securityVisitDao) - if b, ok := b.(*SecurityVisitDao); ok { - securityVisitDao = b - } - - fmt.Printf("Host = %s Uri = %s Path = %s RawPath = %s RawQuery = %s \n", - request.Host, - request.RequestURI, - request.URL.Path, - request.URL.RawPath, - request.URL.RawQuery) - - params := make(map[string][]string) - - //POST请求参数 - values := request.PostForm - for key, val := range values { - params[key] = val - } - //GET请求参数 - values1 := request.URL.Query() - for key, val := range values1 { - params[key] = val - } - - //用json的方式输出返回值。 - paramsString := "{}" - paramsData, err := json.Marshal(params) - if err == nil { - paramsString = string(paramsData) - } - - //将文件信息存入数据库中。 - securityVisit := &SecurityVisit{ - SessionId: "", - UserUuid: "testUserUUid", - Ip: GetIpAddress(request), - Host: request.Host, - Uri: request.URL.Path, - Params: paramsString, - Cost: 0, - Success: true, - } - - securityVisit = securityVisitDao.Create(securityVisit) - -} - //让Router具有处理请求的功能。 func (this *Router) ServeHTTP(writer http.ResponseWriter, request *http.Request) { @@ -147,14 +103,14 @@ func (this *Router) ServeHTTP(writer http.ResponseWriter, request *http.Request) if strings.HasPrefix(path, "/api") { //统一处理用户的身份信息。 - this.userService.enter(writer, request) + this.userService.bootstrap(writer, request) if handler, ok := this.routeMap[path]; ok { handler(writer, request) } else { - //直接将请求扔给每个controller,看看他们能不能处理,如果都不能处理,那就算了。 + //直接将请求扔给每个controller,看看他们能不能处理,如果都不能处理,那就抛出找不到的错误 canHandle := false for _, controller := range CONTEXT.ControllerMap { if handler, exist := controller.HandleRoutes(writer, request); exist { @@ -172,7 +128,7 @@ func (this *Router) ServeHTTP(writer http.ResponseWriter, request *http.Request) } //正常的访问记录会落到这里。 - go this.logSecurityVisit(writer, request) + go this.securityVisitService.Log(writer, request) } else { //当作静态资源处理。默认从当前文件下面的static文件夹中取东西。 diff --git a/rest/security_visit_service.go b/rest/security_visit_service.go index c547f48..866c315 100644 --- a/rest/security_visit_service.go +++ b/rest/security_visit_service.go @@ -1,5 +1,11 @@ package rest +import ( + "encoding/json" + "fmt" + "net/http" +) + //@Service type SecurityVisitService struct { Bean @@ -30,3 +36,57 @@ func (this *SecurityVisitService) Detail(uuid string) *SecurityVisit { return securityVisit } + + + +//记录访问记录 +func (this *SecurityVisitService) Log(writer http.ResponseWriter, request *http.Request) { + //手动装填本实例的Bean. 这里必须要用中间变量方可。 + var securityVisitDao *SecurityVisitDao + b := CONTEXT.GetBean(securityVisitDao) + if b, ok := b.(*SecurityVisitDao); ok { + securityVisitDao = b + } + + fmt.Printf("Host = %s Uri = %s Path = %s RawPath = %s RawQuery = %s \n", + request.Host, + request.RequestURI, + request.URL.Path, + request.URL.RawPath, + request.URL.RawQuery) + + params := make(map[string][]string) + + //POST请求参数 + values := request.PostForm + for key, val := range values { + params[key] = val + } + //GET请求参数 + values1 := request.URL.Query() + for key, val := range values1 { + params[key] = val + } + + //用json的方式输出返回值。 + paramsString := "{}" + paramsData, err := json.Marshal(params) + if err == nil { + paramsString = string(paramsData) + } + + //将文件信息存入数据库中。 + securityVisit := &SecurityVisit{ + SessionId: "", + UserUuid: "testUserUUid", + Ip: GetIpAddress(request), + Host: request.Host, + Uri: request.URL.Path, + Params: paramsString, + Cost: 0, + Success: true, + } + + securityVisit = securityVisitDao.Create(securityVisit) + +} diff --git a/rest/session_dao.go b/rest/session_dao.go index 92a4f2e..15a00c0 100644 --- a/rest/session_dao.go +++ b/rest/session_dao.go @@ -40,18 +40,6 @@ func (this *SessionDao) CheckByUuid(uuid string) *Session { return session } -//按照authentication查询用户。 -func (this *SessionDao) FindByAuthentication(authentication string) *Session { - - var session = &Session{} - db := CONTEXT.DB.Where(&Session{Authentication: authentication}).First(session) - if db.Error != nil { - return nil - } - return session - -} - //创建一个session并且持久化到数据库中。 func (this *SessionDao) Create(session *Session) *Session { @@ -63,6 +51,18 @@ func (this *SessionDao) Create(session *Session) *Session { return session } + +//修改一个session +func (this *SessionDao) Save(session *Session) *Session { + + session.UpdateTime = time.Now() + db := CONTEXT.DB.Save(session) + this.PanicError(db.Error) + + return session +} + + func (this *SessionDao) Delete(uuid string) { session := this.CheckByUuid(uuid) @@ -73,3 +73,5 @@ func (this *SessionDao) Delete(uuid string) { this.PanicError(db.Error) } + + diff --git a/rest/session_model.go b/rest/session_model.go index b10f0b3..8d9ba1c 100644 --- a/rest/session_model.go +++ b/rest/session_model.go @@ -6,7 +6,6 @@ import ( type Session struct { Base - Authentication string `json:"authentication"` UserUuid string `json:"userUuid"` Ip string `json:"ip"` ExpireTime time.Time `json:"expireTime"` diff --git a/rest/user_controller.go b/rest/user_controller.go index d2101d0..1fd0c57 100644 --- a/rest/user_controller.go +++ b/rest/user_controller.go @@ -28,7 +28,7 @@ func (this *UserController) RegisterRoutes() map[string]func(writer http.Respons routeMap["/api/user/change/password"] = this.Wrap(this.ChangePassword, USER_ROLE_USER) routeMap["/api/user/reset/password"] = this.Wrap(this.ResetPassword, USER_ROLE_ADMINISTRATOR) routeMap["/api/user/login"] = this.Wrap(this.Login, USER_ROLE_GUEST) - routeMap["/api/user/logout"] = this.Wrap(this.Logout, USER_ROLE_USER) + routeMap["/api/user/logout"] = this.Wrap(this.Logout, USER_ROLE_GUEST) routeMap["/api/user/detail"] = this.Wrap(this.Detail, USER_ROLE_USER) routeMap["/api/user/page"] = this.Wrap(this.Page, USER_ROLE_ADMINISTRATOR) routeMap["/api/user/disable"] = this.Wrap(this.Disable, USER_ROLE_ADMINISTRATOR) @@ -212,10 +212,26 @@ func (this *UserController) Detail(writer http.ResponseWriter, request *http.Req //退出登录 func (this *UserController) Logout(writer http.ResponseWriter, request *http.Request) *WebResult { - session, _ := this.checkLogin(writer, request) + //session置为过期 + sessionCookie, err := request.Cookie(COOKIE_AUTH_KEY) + if err != nil { + LogError("找不到任何登录信息") + return this.Success("已经退出登录了!") + } + sessionId := sessionCookie.Value - //删除session - this.sessionDao.Delete(session.Uuid) + user := this.findUser(writer, request) + if user != nil { + session := this.sessionDao.FindByUuid(sessionId) + session.ExpireTime = time.Now() + this.sessionDao.Save(session) + } + + //删掉session缓存 + _, err = CONTEXT.SessionCache.Delete(sessionId) + if err != nil { + LogError("删除用户session缓存时出错") + } //清空客户端的cookie. expiration := time.Now() @@ -223,7 +239,7 @@ func (this *UserController) Logout(writer http.ResponseWriter, request *http.Req cookie := http.Cookie{ Name: COOKIE_AUTH_KEY, Path: "/", - Value: session.Uuid, + Value: sessionId, Expires: expiration} http.SetCookie(writer, &cookie) diff --git a/rest/user_service.go b/rest/user_service.go index 91ee11a..cb1ba72 100644 --- a/rest/user_service.go +++ b/rest/user_service.go @@ -2,12 +2,14 @@ package rest import ( "net/http" + "time" ) //@Service type UserService struct { Bean - userDao *UserDao + userDao *UserDao + sessionDao *SessionDao } //初始化方法 @@ -19,11 +21,54 @@ func (this *UserService) Init() { this.userDao = b } + b = CONTEXT.GetBean(this.sessionDao) + if b, ok := b.(*SessionDao); ok { + this.sessionDao = b + } + } //装载session信息,如果session没有了根据cookie去装填用户信息。 //在所有的路由最初会调用这个方法 -func (this *UserService) enter(writer http.ResponseWriter, request *http.Request) { +func (this *UserService) bootstrap(writer http.ResponseWriter, request *http.Request) { + //登录身份有效期以数据库中记录的为准 + + //验证用户是否已经登录。 + sessionCookie, err := request.Cookie(COOKIE_AUTH_KEY) + if err != nil { + LogError("找不到任何登录信息") + return + } + + sessionId := sessionCookie.Value + + LogInfo("请求的sessionId = " + sessionId) + + //去缓存中捞取 + cacheItem, err := CONTEXT.SessionCache.Value(sessionId) + if err != nil { + LogError("获取缓存时出错了" + err.Error()) + } + + //缓存中没有,尝试去数据库捞取 + if cacheItem == nil || cacheItem.Data() == nil { + session := this.sessionDao.FindByUuid(sessionCookie.Value) + if session != nil { + duration := session.ExpireTime.Sub(time.Now()) + if duration <= 0 { + LogError("登录信息已过期") + } else { + user := this.userDao.FindByUuid(session.UserUuid) + if user != nil { + //将用户装填进缓存中 + CONTEXT.SessionCache.Add(sessionCookie.Value, duration, user) + + } else { + LogError("没有找到对应的user " + session.UserUuid) + } + } + } + } } diff --git a/rest/util_cache.go b/rest/util_cache.go index 0ecc617..0457d6c 100644 --- a/rest/util_cache.go +++ b/rest/util_cache.go @@ -334,7 +334,8 @@ func (table *CacheTable) Value(key interface{}, args ...interface{}) (*CacheItem return nil, errors.New("无法加载到缓存值") } - return nil, errors.New(fmt.Sprintf("没有找到%s对应的记录", key)) + //没有找到任何东西,返回nil. + return nil, nil } // 删除缓存表中的所有项目