452 lines
11 KiB
Go
452 lines
11 KiB
Go
package engine
|
||
|
||
import (
|
||
"strings"
|
||
"time"
|
||
|
||
"git.kingecg.top/kingecg/gomog/pkg/types"
|
||
)
|
||
|
||
// applyUpdate 应用更新操作到文档数据
|
||
func applyUpdate(data map[string]interface{}, update types.Update, isUpsertInsert bool) map[string]interface{} {
|
||
return applyUpdateWithFilters(data, update, isUpsertInsert, nil)
|
||
}
|
||
|
||
// applyUpdateWithFilters 应用更新操作(支持 arrayFilters)
|
||
func applyUpdateWithFilters(data map[string]interface{}, update types.Update, isUpsertInsert bool, arrayFilters []types.Filter) map[string]interface{} {
|
||
// 深拷贝原数据
|
||
result := deepCopyMap(data)
|
||
|
||
// 处理 $set
|
||
for field, value := range update.Set {
|
||
if !updateArrayElement(result, field, value, convertFiltersToMaps(arrayFilters)) {
|
||
setNestedValue(result, field, value)
|
||
}
|
||
}
|
||
|
||
// 处理 $unset
|
||
for field := range update.Unset {
|
||
removeNestedValue(result, field)
|
||
}
|
||
|
||
// 处理 $inc
|
||
for field, value := range update.Inc {
|
||
if !updateArrayElement(result, field, value, convertFiltersToMaps(arrayFilters)) {
|
||
incNestedValue(result, field, value)
|
||
}
|
||
}
|
||
|
||
// 处理 $mul
|
||
for field, value := range update.Mul {
|
||
if !updateArrayElement(result, field, value, convertFiltersToMaps(arrayFilters)) {
|
||
mulNestedValue(result, field, value)
|
||
}
|
||
}
|
||
|
||
// 处理 $push
|
||
for field, value := range update.Push {
|
||
pushNestedValue(result, field, value)
|
||
}
|
||
|
||
// 处理 $pull
|
||
for field, value := range update.Pull {
|
||
pullNestedValue(result, field, value)
|
||
}
|
||
|
||
// 处理 $min - 仅当值小于当前值时更新
|
||
for field, value := range update.Min {
|
||
current := getNestedValue(result, field)
|
||
if current == nil || compareNumbers(current, value) > 0 {
|
||
setNestedValue(result, field, value)
|
||
}
|
||
}
|
||
|
||
// 处理 $max - 仅当值大于当前值时更新
|
||
for field, value := range update.Max {
|
||
current := getNestedValue(result, field)
|
||
if current == nil || compareNumbers(current, value) < 0 {
|
||
setNestedValue(result, field, value)
|
||
}
|
||
}
|
||
|
||
// 处理 $rename - 重命名字段
|
||
for oldName, newName := range update.Rename {
|
||
value := getNestedValue(result, oldName)
|
||
if value != nil {
|
||
removeNestedValue(result, oldName)
|
||
setNestedValue(result, newName, value)
|
||
}
|
||
}
|
||
|
||
// 处理 $currentDate - 设置为当前时间
|
||
for field, spec := range update.CurrentDate {
|
||
var currentTime interface{} = time.Now()
|
||
|
||
// 检查是否指定了类型
|
||
if specMap, ok := spec.(map[string]interface{}); ok {
|
||
if typeVal, exists := specMap["$type"]; exists {
|
||
if typeStr, ok := typeVal.(string); ok && typeStr == "timestamp" {
|
||
currentTime = time.Now().UnixMilli()
|
||
}
|
||
}
|
||
}
|
||
|
||
setNestedValue(result, field, currentTime)
|
||
}
|
||
|
||
// 处理 $addToSet - 添加唯一元素到数组
|
||
for field, value := range update.AddToSet {
|
||
current := getNestedValue(result, field)
|
||
var arr []interface{}
|
||
if current != nil {
|
||
if a, ok := current.([]interface{}); ok {
|
||
arr = a
|
||
}
|
||
}
|
||
if arr == nil {
|
||
arr = make([]interface{}, 0)
|
||
}
|
||
|
||
// 检查是否已存在
|
||
exists := false
|
||
for _, item := range arr {
|
||
if compareEq(item, value) {
|
||
exists = true
|
||
break
|
||
}
|
||
}
|
||
|
||
if !exists {
|
||
arr = append(arr, value)
|
||
setNestedValue(result, field, arr)
|
||
}
|
||
}
|
||
|
||
// 处理 $pop - 移除数组首/尾元素
|
||
for field, pos := range update.Pop {
|
||
current := getNestedValue(result, field)
|
||
if arr, ok := current.([]interface{}); ok && len(arr) > 0 {
|
||
if pos >= 0 {
|
||
// 移除最后一个元素
|
||
arr = arr[:len(arr)-1]
|
||
} else {
|
||
// 移除第一个元素
|
||
arr = arr[1:]
|
||
}
|
||
setNestedValue(result, field, arr)
|
||
}
|
||
}
|
||
|
||
// 处理 $pullAll - 从数组中移除多个值
|
||
for field, values := range update.PullAll {
|
||
current := getNestedValue(result, field)
|
||
if arr, ok := current.([]interface{}); ok {
|
||
filtered := make([]interface{}, 0, len(arr))
|
||
for _, item := range arr {
|
||
keep := true
|
||
for _, removeVal := range values {
|
||
if compareEq(item, removeVal) {
|
||
keep = false
|
||
break
|
||
}
|
||
}
|
||
if keep {
|
||
filtered = append(filtered, item)
|
||
}
|
||
}
|
||
setNestedValue(result, field, filtered)
|
||
}
|
||
}
|
||
|
||
// 处理 $setOnInsert - 仅在 upsert 插入时设置
|
||
if isUpsertInsert {
|
||
for field, value := range update.SetOnInsert {
|
||
setNestedValue(result, field, value)
|
||
}
|
||
}
|
||
|
||
return result
|
||
}
|
||
|
||
// convertFiltersToMaps 转换 Filter 数组为 map 数组
|
||
func convertFiltersToMaps(filters []types.Filter) []map[string]interface{} {
|
||
if filters == nil {
|
||
return nil
|
||
}
|
||
result := make([]map[string]interface{}, len(filters))
|
||
for i, f := range filters {
|
||
result[i] = map[string]interface{}(f)
|
||
}
|
||
return result
|
||
}
|
||
|
||
// deepCopyMap returns a deep copy of m. Nested map[string]interface{} and
// []interface{} values are copied recursively. Every other value is copied by
// plain assignment, so pointers and other reference types (typed maps/slices,
// etc.) remain shared with the source. A nil map yields nil.
func deepCopyMap(m map[string]interface{}) map[string]interface{} {
	if m == nil {
		return nil
	}

	// The copy has exactly as many keys as the source; pre-size to avoid rehashing.
	result := make(map[string]interface{}, len(m))
	for k, v := range m {
		switch val := v.(type) {
		case map[string]interface{}:
			result[k] = deepCopyMap(val)
		case []interface{}:
			result[k] = deepCopySlice(val)
		default:
			result[k] = v
		}
	}
	return result
}

// deepCopySlice returns a deep copy of s using the same rules as deepCopyMap:
// nested map[string]interface{} and []interface{} elements are copied
// recursively, all other elements are assigned as-is. A nil slice yields nil.
func deepCopySlice(s []interface{}) []interface{} {
	if s == nil {
		return nil
	}

	result := make([]interface{}, len(s))
	for i, v := range s {
		switch val := v.(type) {
		case map[string]interface{}:
			result[i] = deepCopyMap(val)
		case []interface{}:
			result[i] = deepCopySlice(val)
		default:
			result[i] = v
		}
	}
	return result
}
|
||
|
||
// setNestedValue 设置嵌套字段值
|
||
func setNestedValue(data map[string]interface{}, field string, value interface{}) {
|
||
parts := splitFieldPath(field)
|
||
|
||
current := data
|
||
for i, part := range parts {
|
||
if i == len(parts)-1 {
|
||
// 最后一个部分,设置值
|
||
current[part] = value
|
||
return
|
||
}
|
||
|
||
// 中间部分,确保是 map
|
||
if current[part] == nil {
|
||
current[part] = make(map[string]interface{})
|
||
}
|
||
|
||
if m, ok := current[part].(map[string]interface{}); ok {
|
||
current = m
|
||
} else {
|
||
// 类型不匹配,创建新 map
|
||
newMap := make(map[string]interface{})
|
||
current[part] = newMap
|
||
current = newMap
|
||
}
|
||
}
|
||
}
|
||
|
||
// removeNestedValue 删除嵌套字段
|
||
func removeNestedValue(data map[string]interface{}, field string) {
|
||
parts := splitFieldPath(field)
|
||
|
||
current := data
|
||
for i, part := range parts {
|
||
if i == len(parts)-1 {
|
||
delete(current, part)
|
||
return
|
||
}
|
||
|
||
if m, ok := current[part].(map[string]interface{}); ok {
|
||
current = m
|
||
} else {
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// incNestedValue 递增嵌套字段值
|
||
func incNestedValue(data map[string]interface{}, field string, increment interface{}) {
|
||
current := getNestedValue(data, field)
|
||
if current == nil {
|
||
setNestedValue(data, field, increment)
|
||
return
|
||
}
|
||
|
||
newValue := toFloat64(current) + toFloat64(increment)
|
||
setNestedValue(data, field, newValue)
|
||
}
|
||
|
||
// mulNestedValue 乘以嵌套字段值
|
||
func mulNestedValue(data map[string]interface{}, field string, multiplier interface{}) {
|
||
current := getNestedValue(data, field)
|
||
if current == nil {
|
||
return
|
||
}
|
||
|
||
newValue := toFloat64(current) * toFloat64(multiplier)
|
||
setNestedValue(data, field, newValue)
|
||
}
|
||
|
||
// pushNestedValue 推入数组
|
||
func pushNestedValue(data map[string]interface{}, field string, value interface{}) {
|
||
current := getNestedValue(data, field)
|
||
|
||
var arr []interface{}
|
||
if current != nil {
|
||
if a, ok := current.([]interface{}); ok {
|
||
arr = a
|
||
}
|
||
}
|
||
|
||
if arr == nil {
|
||
arr = make([]interface{}, 0)
|
||
}
|
||
|
||
arr = append(arr, value)
|
||
setNestedValue(data, field, arr)
|
||
}
|
||
|
||
// pullNestedValue 从数组中移除
|
||
func pullNestedValue(data map[string]interface{}, field string, value interface{}) {
|
||
current := getNestedValue(data, field)
|
||
if current == nil {
|
||
return
|
||
}
|
||
|
||
arr, ok := current.([]interface{})
|
||
if !ok {
|
||
return
|
||
}
|
||
|
||
// 过滤掉匹配的值
|
||
filtered := make([]interface{}, 0, len(arr))
|
||
for _, item := range arr {
|
||
if !compareEq(item, value) {
|
||
filtered = append(filtered, item)
|
||
}
|
||
}
|
||
|
||
setNestedValue(data, field, filtered)
|
||
}
|
||
|
||
// splitFieldPath splits a dot-notation path such as "a.b.c" into its
// segments. Escaped dots are not supported. The result always has at least
// one element; an empty path yields [""].
func splitFieldPath(field string) []string {
	const separator = "."
	return strings.SplitN(field, separator, -1)
}
|
||
|
||
// generateIDLast holds the timestamp of the most recently issued ID. The
// one-slot channel acts as a lock: receiving takes ownership of the value and
// sending it back releases it, so concurrent generateID calls are serialized.
var generateIDLast = func() chan time.Time {
	ch := make(chan time.Time, 1)
	ch <- time.Time{}
	return ch
}()

// generateID returns an ID built from the current local time, formatted as
// "20060102150405.000000000" (nanosecond precision). A bare timestamp can
// repeat when callers run concurrently or the clock is coarse. To avoid that,
// each call uses a timestamp strictly later than the previous call's: if the
// clock has not advanced, the previous timestamp plus one nanosecond is used.
// NOTE(review): uniqueness holds within one process only. Because the string
// is local time, a DST fall-back could in principle reproduce an earlier string.
func generateID() string {
	last := <-generateIDLast
	// Round(0) strips the monotonic reading, so the comparison is on wall-clock
	// time, which is what the formatted ID encodes.
	now := time.Now().Round(0)
	if !now.After(last) {
		now = last.Add(time.Nanosecond)
	}
	generateIDLast <- now
	return now.Format("20060102150405.000000000")
}
|
||
|
||
// updateArrayElement 更新数组元素(支持 $ 位置操作符)
|
||
func updateArrayElement(data map[string]interface{}, field string, value interface{}, arrayFilters []map[string]interface{}) bool {
|
||
parts := splitFieldPath(field)
|
||
|
||
// 查找包含 $ 或 $[] 的部分
|
||
for i, part := range parts {
|
||
if part == "$" || part == "$[]" || (len(part) > 2 && part[0] == '$' && part[1] == '[') {
|
||
// 需要数组更新
|
||
return updateArrayAtPath(data, parts, i, value, arrayFilters)
|
||
}
|
||
}
|
||
|
||
// 普通字段更新
|
||
setNestedValue(data, field, value)
|
||
return true
|
||
}
|
||
|
||
// updateArrayAtPath 在指定路径更新数组
|
||
func updateArrayAtPath(data map[string]interface{}, parts []string, index int, value interface{}, arrayFilters []map[string]interface{}) bool {
|
||
// 获取到数组前的路径(导航到父对象)
|
||
current := data
|
||
for i := 0; i < index; i++ {
|
||
if m, ok := current[parts[i]].(map[string]interface{}); ok {
|
||
current = m
|
||
} else if i == index-1 {
|
||
// 最后一个部分应该是数组字段名,不需要是 map
|
||
break
|
||
} else {
|
||
return false
|
||
}
|
||
}
|
||
|
||
// 获取实际的数组字段名(操作符前面的部分)
|
||
var actualFieldName string
|
||
if index > 0 {
|
||
actualFieldName = parts[index-1]
|
||
} else {
|
||
return false // 无效的路径
|
||
}
|
||
|
||
arrField := parts[index]
|
||
arr := getNestedValue(data, actualFieldName)
|
||
array, ok := arr.([]interface{})
|
||
if !ok || len(array) == 0 {
|
||
return false
|
||
}
|
||
|
||
// 处理不同的位置操作符
|
||
if arrField == "$" {
|
||
// 定位第一个匹配的元素(需要配合查询条件)
|
||
// 简化实现:更新第一个元素
|
||
array[0] = value
|
||
setNestedValue(data, actualFieldName, array)
|
||
return true
|
||
}
|
||
|
||
if arrField == "$[]" {
|
||
// 更新所有元素
|
||
for i := range array {
|
||
array[i] = value
|
||
}
|
||
setNestedValue(data, actualFieldName, array)
|
||
return true
|
||
}
|
||
|
||
// 处理 $[identifier] 形式
|
||
if len(arrField) > 3 && arrField[0] == '$' && arrField[1] == '[' && arrField[len(arrField)-1] == ']' {
|
||
identifier := arrField[2 : len(arrField)-1]
|
||
|
||
// 查找匹配的 arrayFilter
|
||
var filter map[string]interface{}
|
||
for _, f := range arrayFilters {
|
||
if idVal, exists := f["identifier"]; exists && idVal == identifier {
|
||
// 复制 filter 并移除 identifier 字段
|
||
filter = make(map[string]interface{})
|
||
for k, v := range f {
|
||
if k != "identifier" {
|
||
filter[k] = v
|
||
}
|
||
}
|
||
break
|
||
}
|
||
}
|
||
|
||
if filter != nil && len(filter) > 0 {
|
||
// 应用过滤器更新匹配的元素
|
||
for i, item := range array {
|
||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||
if MatchFilter(itemMap, filter) {
|
||
// 如果是嵌套字段(如 students.$[elem].grade),需要设置嵌套字段
|
||
if index+1 < len(parts) {
|
||
// 还有后续字段,设置嵌套字段
|
||
itemMap[parts[index+1]] = value
|
||
} else {
|
||
array[i] = value
|
||
}
|
||
}
|
||
}
|
||
}
|
||
setNestedValue(data, actualFieldName, array)
|
||
return true
|
||
}
|
||
}
|
||
|
||
return false
|
||
}
|