gomog/internal/database/base.go

243 lines
5.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package database
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"git.kingecg.top/kingecg/gomog/pkg/types"
)
// BaseAdapter is the shared, database-agnostic part of an adapter. It keeps
// the database/sql handle and the driver name; database-specific adapters
// reuse it and reach the handle through GetDB.
type BaseAdapter struct {
	db         *sql.DB // handle assigned by Connect; nil before that
	driverName string  // driver name passed to sql.Open by Connect
}
// NewBaseAdapter returns a new, not yet connected adapter that will open
// connections with the given database/sql driver name. Call Connect before
// issuing any queries.
func NewBaseAdapter(driverName string) *BaseAdapter {
	adapter := new(BaseAdapter)
	adapter.driverName = driverName
	return adapter
}
// GetDB returns the underlying *sql.DB so that database-specific adapters
// built on BaseAdapter can run their own statements. It returns nil until
// Connect has assigned a handle.
func (a *BaseAdapter) GetDB() *sql.DB {
	return a.db
}
// Connect opens a database/sql handle for the adapter's driver using dsn and
// verifies it with PingContext.
//
// The handle is stored on the adapter only when the ping succeeds. If the
// ping fails, the handle is closed so the connection pool is not leaked and
// the adapter is not left holding a connection that never worked. Errors are
// wrapped with %w, so errors.Is/errors.As still see the driver's error.
//
// NOTE(review): calling Connect again replaces a.db without closing the
// previous handle; call Close first when reconnecting.
func (a *BaseAdapter) Connect(ctx context.Context, dsn string) error {
	db, err := sql.Open(a.driverName, dsn)
	if err != nil {
		return fmt.Errorf("open %s database: %w", a.driverName, err)
	}
	if err := db.PingContext(ctx); err != nil {
		// Best effort: the ping error is the one worth reporting.
		_ = db.Close()
		return fmt.Errorf("ping %s database: %w", a.driverName, err)
	}
	a.db = db
	return nil
}
// Close shuts down the adapter's database handle. On an adapter that has no
// handle (Connect was never called) it does nothing and returns nil.
func (a *BaseAdapter) Close() error {
	if a.db == nil {
		return nil
	}
	return a.db.Close()
}
// Ping checks that the database behind the adapter is still reachable.
// Before a handle has been assigned by Connect it returns an error instead of
// dereferencing a nil *sql.DB and panicking.
func (a *BaseAdapter) Ping(ctx context.Context) error {
	if a.db == nil {
		return fmt.Errorf("database: not connected")
	}
	return a.db.PingContext(ctx)
}
// CreateCollection creates the table that backs a collection, unless it
// already exists. Every collection uses the same layout: a TEXT primary-key
// id, the document body in a JSON data column, and created_at / updated_at
// timestamps defaulting to CURRENT_TIMESTAMP.
//
// NOTE(review): name is interpolated into the SQL unquoted; callers must pass
// a trusted identifier.
func (a *BaseAdapter) CreateCollection(ctx context.Context, name string) error {
	const ddlTemplate = `
CREATE TABLE IF NOT EXISTS %s (
id TEXT PRIMARY KEY,
data JSON NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`
	if _, execErr := a.db.ExecContext(ctx, fmt.Sprintf(ddlTemplate, name)); execErr != nil {
		return execErr
	}
	return nil
}
// DropCollection drops the table backing a collection. It does not fail if
// the table is already gone (DROP TABLE IF EXISTS).
//
// NOTE(review): name is interpolated into the SQL unquoted; callers must pass
// a trusted identifier.
func (a *BaseAdapter) DropCollection(ctx context.Context, name string) error {
	dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s", name)
	if _, execErr := a.db.ExecContext(ctx, dropSQL); execErr != nil {
		return execErr
	}
	return nil
}
// CollectionExists reports whether the table backing a collection exists.
// Each database keeps its catalog in different system tables, so the base
// adapter cannot answer this: it always returns false with ErrNotImplemented,
// and database-specific adapters are expected to provide the real check.
func (a *BaseAdapter) CollectionExists(ctx context.Context, name string) (bool, error) {
	const unknown = false
	return unknown, ErrNotImplemented
}
// InsertMany inserts docs into collection inside a single transaction, using
// one prepared INSERT statement for every document. Each document's Data is
// JSON-encoded, and created_at and updated_at are both set to the time the
// row is written. The first marshal or exec error aborts the batch; the
// deferred Rollback then discards every row inserted so far.
func (a *BaseAdapter) InsertMany(ctx context.Context, collection string, docs []types.Document) error {
	tx, beginErr := a.db.BeginTx(ctx, nil)
	if beginErr != nil {
		return beginErr
	}
	// Rollback after a successful Commit is a harmless no-op.
	defer tx.Rollback()

	insertSQL := fmt.Sprintf("INSERT INTO %s (id, data, created_at, updated_at) VALUES (?, ?, ?, ?)", collection)
	stmt, prepErr := tx.PrepareContext(ctx, insertSQL)
	if prepErr != nil {
		return prepErr
	}
	defer stmt.Close()

	for i := range docs {
		payload, marshalErr := json.Marshal(docs[i].Data)
		if marshalErr != nil {
			return marshalErr
		}
		ts := time.Now()
		if _, execErr := stmt.ExecContext(ctx, docs[i].ID, payload, ts, ts); execErr != nil {
			return execErr
		}
	}
	return tx.Commit()
}
// UpdateMany applies update to every document whose id is in ids. All rows
// are changed inside a single transaction, and updated_at is set to the time
// of the call.
//
// All $set fields are written with one json_set call, and all $unset fields
// are removed with one json_remove call wrapped around it, so every field in
// the update is applied. JSON paths and values are bound as parameters rather
// than spliced into the SQL text. Each $set value is encoded by toJSONString
// and passed through json(?), so it is stored as a JSON value (number, object,
// ...) instead of as a quoted JSON string.
//
// An update with no $set and no $unset fields is a no-op; no transaction is
// opened.
//
// NOTE(review): json(), json_set and json_remove follow SQLite's JSON1
// syntax; adapters for databases with different JSON functions should
// override this method.
func (a *BaseAdapter) UpdateMany(ctx context.Context, collection string, ids []string, update types.Update) error {
	// expr becomes json_remove(json_set(data, ?, json(?), ...), ?, ...).
	// args is filled in exactly the order its placeholders appear in expr.
	expr := "data"
	args := make([]interface{}, 0, 2*len(update.Set)+len(update.Unset)+2)
	if len(update.Set) > 0 {
		setExpr := "json_set(" + expr
		for field, value := range update.Set {
			setExpr += ", ?, json(?)"
			args = append(args, fmt.Sprintf("$.%s", field), toJSONString(value))
		}
		expr = setExpr + ")"
	}
	if len(update.Unset) > 0 {
		removeExpr := "json_remove(" + expr
		for field := range update.Unset {
			removeExpr += ", ?"
			args = append(args, fmt.Sprintf("$.%s", field))
		}
		expr = removeExpr + ")"
	}
	if len(args) == 0 {
		return nil
	}

	tx, err := a.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer tx.Rollback() // no-op once Commit has succeeded

	stmt, err := tx.PrepareContext(ctx, fmt.Sprintf(
		"UPDATE %s SET data = %s, updated_at = ? WHERE id = ?", collection, expr))
	if err != nil {
		return err
	}
	defer stmt.Close()

	// The two trailing placeholders are updated_at and id, in that order;
	// only the id slot changes from row to row.
	rowArgs := append(args, time.Now(), nil)
	idSlot := len(rowArgs) - 1
	for _, id := range ids {
		rowArgs[idSlot] = id
		if _, err := stmt.ExecContext(ctx, rowArgs...); err != nil {
			return err
		}
	}
	return tx.Commit()
}
// DeleteMany deletes every document whose id appears in ids, using a single
// DELETE ... WHERE id IN (...) statement with the ids bound as parameters.
// An empty ids slice is a no-op and does not touch the database.
//
// NOTE(review): databases cap the number of bound parameters per statement
// (e.g. SQLite's SQLITE_MAX_VARIABLE_NUMBER); very large id lists may need
// to be chunked by the caller.
func (a *BaseAdapter) DeleteMany(ctx context.Context, collection string, ids []string) error {
	if len(ids) == 0 {
		return nil
	}
	// Build "?, ?, ?" with one placeholder per id. Formatting a []string
	// with %s would yield "[? ? ?]", which is not valid SQL.
	placeholders := make([]byte, 0, 3*len(ids))
	args := make([]interface{}, len(ids))
	for i, id := range ids {
		if i > 0 {
			placeholders = append(placeholders, ", "...)
		}
		placeholders = append(placeholders, '?')
		args[i] = id
	}
	query := fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", collection, placeholders)
	_, err := a.db.ExecContext(ctx, query, args...)
	return err
}
// FindAll loads every document in collection. For each row it scans id,
// created_at and updated_at straight into the Document and JSON-decodes the
// data column into Document.Data. It returns a nil slice when the table is
// empty. The first scan or decode error aborts with a nil result; an error
// reported by rows.Err after iteration is returned together with the
// documents read so far.
func (a *BaseAdapter) FindAll(ctx context.Context, collection string) ([]types.Document, error) {
	rows, queryErr := a.db.QueryContext(ctx, fmt.Sprintf("SELECT id, data, created_at, updated_at FROM %s", collection))
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	var result []types.Document
	for rows.Next() {
		var (
			doc types.Document
			raw []byte
		)
		if scanErr := rows.Scan(&doc.ID, &raw, &doc.CreatedAt, &doc.UpdatedAt); scanErr != nil {
			return nil, scanErr
		}
		if decodeErr := json.Unmarshal(raw, &doc.Data); decodeErr != nil {
			return nil, decodeErr
		}
		result = append(result, doc)
	}
	return result, rows.Err()
}
// BeginTx starts a database transaction with the driver's default options
// and returns it wrapped in a baseTransaction. On failure it returns a nil
// Transaction together with the driver error.
func (a *BaseAdapter) BeginTx(ctx context.Context) (Transaction, error) {
	sqlTx, beginErr := a.db.BeginTx(ctx, nil)
	if beginErr != nil {
		return nil, beginErr
	}
	wrapped := &baseTransaction{tx: sqlTx}
	return wrapped, nil
}
// baseTransaction adapts a *sql.Tx to the package's Transaction interface by
// forwarding Commit and Rollback to it.
type baseTransaction struct {
	tx *sql.Tx
}

// Commit commits the wrapped *sql.Tx.
func (bt *baseTransaction) Commit() error { return bt.tx.Commit() }

// Rollback aborts the wrapped *sql.Tx. After a Commit it returns the error
// reported by *sql.Tx (sql.ErrTxDone).
func (bt *baseTransaction) Rollback() error { return bt.tx.Rollback() }
// toJSONString encodes v as JSON text. A nil v yields "null".
// NOTE(review): if v cannot be marshalled (e.g. a channel or func), the error
// is dropped and the empty string is returned.
func toJSONString(v interface{}) string {
	if v == nil {
		return "null"
	}
	encoded, err := json.Marshal(v)
	if err != nil {
		return ""
	}
	return string(encoded)
}