package http

import (
	"context"
	"encoding/json"
	"net/http"
	"strings"
	"time"

	"git.kingecg.top/kingecg/gomog/internal/engine"
	"git.kingecg.top/kingecg/gomog/pkg/types"
)

// HTTPServer is the HTTP front end: it owns the route mux, the request
// handler that executes operations, and the underlying *http.Server.
type HTTPServer struct {
	mux     *http.ServeMux
	handler *RequestHandler
	server  *http.Server
}

// RequestHandler executes find/insert/update/delete/aggregate requests
// against the engine.
type RequestHandler struct {
	store *engine.MemoryStore
	crud  *engine.CRUDHandler // NOTE(review): not referenced in this file.
	agg   *engine.AggregationEngine
}

// NewRequestHandler creates a RequestHandler backed by the given store, CRUD
// handler and aggregation engine.
func NewRequestHandler(store *engine.MemoryStore, crud *engine.CRUDHandler, agg *engine.AggregationEngine) *RequestHandler {
	return &RequestHandler{
		store: store,
		crud:  crud,
		agg:   agg,
	}
}

// NewHTTPServer creates an HTTPServer listening on addr, with all routes
// registered and conservative connection timeouts applied.
func NewHTTPServer(addr string, handler *RequestHandler) *HTTPServer {
	s := &HTTPServer{
		mux:     http.NewServeMux(),
		handler: handler,
	}

	s.registerRoutes()

	// Without timeouts a slow or stalled client can hold a connection (and
	// its goroutine) open indefinitely.
	s.server = &http.Server{
		Addr:              addr,
		Handler:           s.mux,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       60 * time.Second,
		WriteTimeout:      60 * time.Second,
		IdleTimeout:       120 * time.Second,
	}

	return s
}

// Start listens on the configured address and serves requests. It blocks
// until the server stops; after Shutdown it returns http.ErrServerClosed.
func (s *HTTPServer) Start() error {
	return s.server.ListenAndServe()
}

// Shutdown gracefully stops the server, waiting for in-flight requests until
// ctx is done.
func (s *HTTPServer) Shutdown(ctx context.Context) error {
	return s.server.Shutdown(ctx)
}

// registerRoutes wires the API, health-check and root handlers into the mux.
func (s *HTTPServer) registerRoutes() {
	// API v1: /api/v1/{database}/{collection}/{operation}
	s.mux.HandleFunc("/api/v1/", s.handleAPI)

	// Health check.
	s.mux.HandleFunc("/health", s.handleHealth)

	// Root (also the mux's catch-all pattern).
	s.mux.HandleFunc("/", s.handleRoot)
}

// handleRoot reports basic server information for "/" and 404 for any other
// path that reached the catch-all pattern.
func (s *HTTPServer) handleRoot(w http.ResponseWriter, r *http.Request) {
	// "/" matches every path not claimed by a more specific pattern, so only
	// the exact root should answer with server info.
	if r.URL.Path != "/" {
		s.sendError(w, http.StatusNotFound, "Not found: "+r.URL.Path)
		return
	}

	response := map[string]interface{}{
		"name":    "Gomog Server",
		"version": "1.0.0",
		"status":  "running",
	}
	s.sendJSON(w, http.StatusOK, response)
}

// handleHealth answers the health check with a static "healthy" status.
func (s *HTTPServer) handleHealth(w http.ResponseWriter, r *http.Request) {
	response := map[string]interface{}{
		"status": "healthy",
	}
	s.sendJSON(w, http.StatusOK, response)
}

// handleAPI parses /api/v1/{database}/{collection}/{operation}, ensures the
// collection is available and dispatches to the matching RequestHandler method.
func (s *HTTPServer) handleAPI(w http.ResponseWriter, r *http.Request) {
	path := strings.TrimPrefix(r.URL.Path, "/api/v1/")
	parts := strings.Split(path, "/")
	if len(parts) < 3 || parts[0] == "" || parts[1] == "" || parts[2] == "" {
		s.sendError(w, http.StatusBadRequest, "Invalid path. Expected: /api/v1/{database}/{collection}/{operation}")
		return
	}

	dbName := parts[0]
	collection := parts[1]
	operation := parts[2]

	// Bound the request body so one oversized payload cannot exhaust memory
	// while it is being JSON-decoded.
	const maxRequestBodyBytes = 64 << 20 // 64 MiB
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodyBytes)

	// Make sure the collection is available in memory.
	if err := s.loadCollectionIfNeeded(dbName, collection); err != nil {
		s.sendError(w, http.StatusInternalServerError, "Failed to load collection: "+err.Error())
		return
	}

	switch operation {
	case "find":
		s.handler.HandleFind(w, r, dbName, collection)
	case "insert":
		s.handler.HandleInsert(w, r, dbName, collection)
	case "update":
		s.handler.HandleUpdate(w, r, dbName, collection)
	case "delete":
		s.handler.HandleDelete(w, r, dbName, collection)
	case "aggregate":
		s.handler.HandleAggregate(w, r, dbName, collection)
	default:
		s.sendError(w, http.StatusBadRequest, "Unknown operation: "+operation)
	}
}

// loadCollectionIfNeeded is meant to make sure the collection is present in
// the memory store before an operation runs. Loading from a backing database
// is not implemented yet, so it currently always returns nil.
func (s *HTTPServer) loadCollectionIfNeeded(dbName, collection string) error {
	fullCollection := dbName + "." + collection
	if _, err := s.handler.store.GetCollection(fullCollection); err == nil {
		return nil // already loaded
	}

	// TODO: load the collection from the backing database, e.g.
	// return s.handler.store.LoadCollection(context.Background(), fullCollection)
	return nil
}

// sendJSON writes data as a JSON response with the given status code.
func (s *HTTPServer) sendJSON(w http.ResponseWriter, status int, data interface{}) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status line is already sent, so an encoding error can no longer be
	// reported to the client; it is deliberately ignored.
	_ = json.NewEncoder(w).Encode(data)
}

// sendError writes a JSON error body of the form {ok: 0, error, status}.
func (s *HTTPServer) sendError(w http.ResponseWriter, status int, message string) {
	s.sendJSON(w, status, map[string]interface{}{
		"ok":     0,
		"error":  message,
		"status": status,
	})
}

// HandleFind runs a query and returns the matching documents as a single
// cursor batch. POST requests carry a types.FindRequest body; GET requests
// run with an empty request (no filter, skip, limit or projection).
func (h *RequestHandler) HandleFind(w http.ResponseWriter, r *http.Request, dbName, collection string) {
	if r.Method != http.MethodPost && r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req types.FindRequest
	if r.Method == http.MethodPost {
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
	}

	fullCollection := dbName + "." + collection
	docs, err := h.store.Find(fullCollection, req.Filter)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// TODO: req.Sort is accepted but not applied yet; documents are returned
	// in the order store.Find produced them.

	// Skip, then limit. Skipping past the end yields an empty batch.
	skip := req.Skip
	if skip > len(docs) {
		skip = len(docs)
	}
	if skip > 0 {
		docs = docs[skip:]
	}
	if limit := req.Limit; limit > 0 && limit < len(docs) {
		docs = docs[:limit]
	}

	if len(req.Projection) > 0 {
		docs = applyProjection(docs, req.Projection)
	}

	response := types.Response{
		OK: 1,
		Cursor: &types.Cursor{
			FirstBatch: docs,
			ID:         0,
			NS:         fullCollection,
		},
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(response)
}

// HandleInsert inserts every document from a types.InsertRequest, assigning
// each a generated ID, and reports the IDs keyed by their position in the
// request.
func (h *RequestHandler) HandleInsert(w http.ResponseWriter, r *http.Request, dbName, collection string) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req types.InsertRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	fullCollection := dbName + "." + collection
	insertedIDs := make(map[int]string, len(req.Documents))

	for i, docData := range req.Documents {
		id := generateID()
		// A single timestamp so CreatedAt and UpdatedAt start out equal.
		now := time.Now()
		doc := types.Document{
			ID:        id,
			Data:      docData,
			CreatedAt: now,
			UpdatedAt: now,
		}

		if err := h.store.Insert(fullCollection, doc); err != nil {
			// NOTE(review): this handler does not roll back documents
			// inserted earlier in this loop.
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		insertedIDs[i] = id
	}

	response := types.InsertResult{
		OK:          1,
		N:           len(req.Documents),
		InsertedIDs: insertedIDs,
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(response)
}

// HandleUpdate applies every update statement of a types.UpdateRequest and
// reports matched/modified totals plus any upserted IDs, each tagged with the
// index of the statement that produced it.
func (h *RequestHandler) HandleUpdate(w http.ResponseWriter, r *http.Request, dbName, collection string) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req types.UpdateRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	fullCollection := dbName + "." + collection
	totalMatched := 0
	totalModified := 0
	// Non-nil so an empty result encodes as [] rather than null.
	upserted := make([]types.UpsertID, 0)

	for i, op := range req.Updates {
		matched, modified, upsertedIDs, err := h.store.Update(fullCollection, op.Q, op.U, op.Upsert, op.ArrayFilters)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		totalMatched += matched
		totalModified += modified

		for _, id := range upsertedIDs {
			upserted = append(upserted, types.UpsertID{
				Index: i,
				ID:    id,
			})
		}
	}

	response := types.UpdateResult{
		OK:        1,
		N:         totalMatched,
		NModified: totalModified,
		Upserted:  upserted,
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(response)
}

// HandleDelete executes every delete statement of a types.DeleteRequest and
// reports the total number of documents removed.
func (h *RequestHandler) HandleDelete(w http.ResponseWriter, r *http.Request, dbName, collection string) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req types.DeleteRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	fullCollection := dbName + "." + collection
	totalDeleted := 0

	for _, op := range req.Deletes {
		// TODO: op.Limit (1 = delete only the first match) is not enforced;
		// store.Delete is called with the filter only. Each statement is still
		// executed so later statements are never silently skipped.
		deleted, err := h.store.Delete(fullCollection, op.Q)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		totalDeleted += deleted
	}

	response := types.DeleteResult{
		OK:           1,
		N:            totalDeleted,
		DeletedCount: totalDeleted,
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(response)
}

// HandleAggregate runs a types.AggregateRequest pipeline through the
// aggregation engine and returns its results.
func (h *RequestHandler) HandleAggregate(w http.ResponseWriter, r *http.Request, dbName, collection string) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req types.AggregateRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	fullCollection := dbName + "." + collection
	results, err := h.agg.Execute(fullCollection, req.Pipeline)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	response := types.AggregateResult{
		OK:     1,
		Result: results,
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(response)
}

// applyProjection returns copies of docs reshaped by projection:
//   - inclusion mode (some non-_id field is truthy): only the listed fields
//     are kept, and fields missing from a document are omitted. A dotted
//     field is stored under the dotted key itself (e.g. "a.b"), not re-nested;
//   - exclusion mode (no non-_id field is truthy, or the projection names only
//     _id with a falsy value): every field is kept except the listed ones;
//   - _id is set from Document.ID unless the projection sets _id falsy.
//
// The input documents are not mutated.
func applyProjection(docs []types.Document, projection types.Projection) []types.Document {
	includeID := true
	if v, ok := projection["_id"]; ok {
		includeID = isTrue(v)
	}

	// A projection naming only _id is inclusion when _id is truthy
	// ({_id: 1} -> just _id) and exclusion when it is falsy ({_id: 0}).
	inclusion := false
	hasOtherFields := false
	for field, v := range projection {
		if field == "_id" {
			continue
		}
		hasOtherFields = true
		if isTrue(v) {
			inclusion = true
			break
		}
	}
	if !hasOtherFields {
		inclusion = includeID
	}

	result := make([]types.Document, len(docs))
	for i, doc := range docs {
		var projected map[string]interface{}

		if inclusion {
			projected = make(map[string]interface{}, len(projection))
			for field, v := range projection {
				if field == "_id" || !isTrue(v) {
					continue
				}
				if value, found := lookupNestedValue(doc.Data, field); found {
					projected[field] = value
				}
			}
		} else {
			// Shallow copy; deleteNestedValue clones any nested map it
			// descends into, so the stored document stays untouched.
			projected = make(map[string]interface{}, len(doc.Data))
			for k, v := range doc.Data {
				projected[k] = v
			}
			for field, v := range projection {
				if field != "_id" && !isTrue(v) {
					deleteNestedValue(projected, field)
				}
			}
		}

		if includeID {
			projected["_id"] = doc.ID
		} else {
			delete(projected, "_id")
		}

		result[i] = types.Document{
			ID:   doc.ID,
			Data: projected,
		}
	}

	return result
}

// generateID returns a new document ID from engine.GenerateID.
func generateID() string {
	return engine.GenerateID()
}

// isTrue reports whether a projection value selects its field: booleans are
// used as-is, numbers are true when non-zero, and any other value (including
// nil) counts as true.
func isTrue(v interface{}) bool {
	switch val := v.(type) {
	case bool:
		return val
	case int:
		return val != 0
	case int32:
		return val != 0
	case int64:
		return val != 0
	case float32:
		return val != 0
	case float64:
		return val != 0
	}
	return true
}

// getNestedValue returns the value at a dotted path such as "a.b.c", or nil
// when any segment is missing or an intermediate value is not a map.
func getNestedValue(doc map[string]interface{}, key string) interface{} {
	value, _ := lookupNestedValue(doc, key)
	return value
}

// lookupNestedValue resolves a dotted path like getNestedValue, additionally
// reporting whether the path exists (distinguishing a missing field from an
// explicit nil value).
func lookupNestedValue(doc map[string]interface{}, key string) (interface{}, bool) {
	var current interface{} = doc
	for _, part := range strings.Split(key, ".") {
		m, ok := current.(map[string]interface{})
		if !ok {
			return nil, false
		}
		current, ok = m[part]
		if !ok {
			return nil, false
		}
	}
	return current, true
}

// deleteNestedValue removes the field at a dotted path from doc. Every nested
// map on the way down is replaced by a clone before descending, so maps shared
// with other documents are never modified. Missing paths are ignored.
func deleteNestedValue(doc map[string]interface{}, key string) {
	parts := strings.Split(key, ".")
	current := doc
	for _, part := range parts[:len(parts)-1] {
		child, ok := current[part].(map[string]interface{})
		if !ok {
			return // path does not exist; nothing to delete
		}
		clone := make(map[string]interface{}, len(child))
		for k, v := range child {
			clone[k] = v
		}
		current[part] = clone
		current = clone
	}
	delete(current, parts[len(parts)-1])
}