179 lines
4.2 KiB
Go
179 lines
4.2 KiB
Go
package postgres
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"time"
|
||
|
||
"git.kingecg.top/kingecg/gomog/internal/database"
|
||
"git.kingecg.top/kingecg/gomog/pkg/types"
|
||
_ "github.com/lib/pq"
|
||
)
|
||
|
||
// PostgresAdapter PostgreSQL 数据库适配器
|
||
type PostgresAdapter struct {
|
||
*database.BaseAdapter
|
||
}
|
||
|
||
// NewPostgresAdapter 创建 PostgreSQL 适配器
|
||
func NewPostgresAdapter() *PostgresAdapter {
|
||
return &PostgresAdapter{
|
||
BaseAdapter: database.NewBaseAdapter("postgres"),
|
||
}
|
||
}
|
||
|
||
// Connect 连接 PostgreSQL 数据库
|
||
func (a *PostgresAdapter) Connect(ctx context.Context, dsn string) error {
|
||
if err := a.BaseAdapter.Connect(ctx, dsn); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 设置 PostgreSQL 会话参数
|
||
_, err := a.GetDB().Exec("SET timezone = 'UTC'")
|
||
return err
|
||
}
|
||
|
||
// CreateCollection 创建集合(PostgreSQL 表)
|
||
func (a *PostgresAdapter) CreateCollection(ctx context.Context, name string) error {
|
||
// PostgreSQL 使用 JSONB 类型(二进制 JSON,更高效)
|
||
query := fmt.Sprintf(`
|
||
CREATE TABLE IF NOT EXISTS %s (
|
||
id TEXT PRIMARY KEY,
|
||
data JSONB NOT NULL,
|
||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
||
)`, name)
|
||
|
||
_, err := a.GetDB().ExecContext(ctx, query)
|
||
return err
|
||
}
|
||
|
||
// CollectionExists 检查集合是否存在
|
||
func (a *PostgresAdapter) CollectionExists(ctx context.Context, name string) (bool, error) {
|
||
query := `SELECT COUNT(*) FROM information_schema.tables
|
||
WHERE table_schema = 'public' AND table_name = $1`
|
||
var count int
|
||
err := a.GetDB().QueryRowContext(ctx, query, name).Scan(&count)
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return count > 0, nil
|
||
}
|
||
|
||
// FindAll 查询所有文档(使用 PostgreSQL JSONB)
|
||
func (a *PostgresAdapter) FindAll(ctx context.Context, collection string) ([]types.Document, error) {
|
||
query := fmt.Sprintf("SELECT id, data::text, created_at, updated_at FROM %s", collection)
|
||
rows, err := a.GetDB().QueryContext(ctx, query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var docs []types.Document
|
||
for rows.Next() {
|
||
var doc types.Document
|
||
var jsonData string
|
||
err := rows.Scan(&doc.ID, &jsonData, &doc.CreatedAt, &doc.UpdatedAt)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if err := json.Unmarshal([]byte(jsonData), &doc.Data); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
docs = append(docs, doc)
|
||
}
|
||
|
||
return docs, rows.Err()
|
||
}
|
||
|
||
// InsertMany 批量插入(PostgreSQL 优化版本)
|
||
func (a *PostgresAdapter) InsertMany(ctx context.Context, collection string, docs []types.Document) error {
|
||
tx, err := a.GetDB().BeginTx(ctx, nil)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer tx.Rollback()
|
||
|
||
for _, doc := range docs {
|
||
jsonData, err := json.Marshal(doc.Data)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
query := fmt.Sprintf(
|
||
"INSERT INTO %s (id, data, created_at, updated_at) VALUES ($1, $2::jsonb, $3, $4)",
|
||
collection,
|
||
)
|
||
|
||
now := doc.CreatedAt
|
||
if now.IsZero() {
|
||
now = doc.UpdatedAt
|
||
}
|
||
if now.IsZero() {
|
||
now = doc.UpdatedAt
|
||
}
|
||
|
||
_, err = tx.ExecContext(ctx, query, doc.ID, string(jsonData), now, now)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return tx.Commit()
|
||
}
|
||
|
||
// UpdateMany 批量更新(使用 PostgreSQL JSONB 操作符)
|
||
func (a *PostgresAdapter) UpdateMany(ctx context.Context, collection string, ids []string, update types.Update) error {
|
||
if len(ids) == 0 {
|
||
return nil
|
||
}
|
||
|
||
tx, err := a.GetDB().BeginTx(ctx, nil)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer tx.Rollback()
|
||
|
||
// 构建更新表达式
|
||
updateExpr := "data"
|
||
args := make([]interface{}, 0)
|
||
argIndex := 1
|
||
|
||
// 处理 $set - 使用 JSONB 合并
|
||
if len(update.Set) > 0 {
|
||
setJSON, _ := json.Marshal(update.Set)
|
||
updateExpr = fmt.Sprintf("%s || $%d::jsonb", updateExpr, argIndex)
|
||
args = append(args, string(setJSON))
|
||
argIndex++
|
||
}
|
||
|
||
// 处理 $unset - 使用 JSONB 减号操作符
|
||
for field := range update.Unset {
|
||
updateExpr = fmt.Sprintf("%s - $%d", updateExpr, argIndex)
|
||
args = append(args, field)
|
||
argIndex++
|
||
}
|
||
|
||
// 为每个 ID 执行更新
|
||
for _, id := range ids {
|
||
query := fmt.Sprintf(
|
||
"UPDATE %s SET data = %s, updated_at = $%d WHERE id = $%d",
|
||
collection,
|
||
updateExpr,
|
||
argIndex,
|
||
argIndex+1,
|
||
)
|
||
|
||
finalArgs := append(args, time.Now(), id)
|
||
_, err = tx.ExecContext(ctx, query, finalArgs...)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return tx.Commit()
|
||
}
|