Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions db_changes/db/dialect_clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,10 @@ func convertOpToClickhouseValues(o *Operation) ([]any, error) {
values := make([]any, len(o.data))
for i, v := range columns {
if col, exists := o.table.columnsByName[v]; exists {
convertedType, err := convertToType(o.data[v], col.scanType)
fieldData := o.data[v]
convertedType, err := convertToType(fieldData.Value, col.scanType)
if err != nil {
return nil, fmt.Errorf("converting value %q to type %q in column %q: %w", o.data[v], col.scanType, v, err)
return nil, fmt.Errorf("converting value %q to type %q in column %q: %w", fieldData.Value, col.scanType, v, err)
}
values[i] = convertedType
} else {
Expand Down
85 changes: 76 additions & 9 deletions db_changes/db/dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,30 @@ func (d PostgresDialect) saveRow(op, schema, escapedTableName string, primaryKey

}

// getResultCast returns the appropriate cast suffix for the result of arithmetic operations
// based on the column's scan type. TEXT columns need ::text cast, numeric types don't need cast.
func getResultCast(scanType reflect.Type) string {
if scanType == nil {
return "" // unknown type, let PostgreSQL handle it
}
switch scanType.Kind() {
case reflect.String:
return "::text" // TEXT columns need explicit cast from numeric
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64:
return "" // numeric types don't need cast, PostgreSQL will handle it
default:
return "" // unknown type, let PostgreSQL handle it
}
}

func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (normalQuery string, undoQuery string, err error) {
var columns, values []string
var updateOps []UpdateOp
var scanTypes []reflect.Type
if o.opType == OperationTypeInsert || o.opType == OperationTypeUpsert || o.opType == OperationTypeUpdate {
columns, values, err = d.prepareColValues(o.table, o.data)
columns, values, updateOps, scanTypes, err = d.prepareColValues(o.table, o.data)
if err != nil {
return "", "", fmt.Errorf("preparing column & values: %w", err)
}
Expand All @@ -415,9 +435,30 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (normalQ
return insertQuery, "", nil

case OperationTypeUpsert:
// Build per-field update expressions based on UpdateOp
updates := make([]string, len(columns))
for i := range columns {
updates[i] = fmt.Sprintf("%s=EXCLUDED.%s", columns[i], columns[i])
col := columns[i]
resultCast := getResultCast(scanTypes[i])
switch updateOps[i] {
case UpdateOpSet:
// Direct assignment: col = EXCLUDED.col
updates[i] = fmt.Sprintf("%s=EXCLUDED.%s", col, col)
case UpdateOpAdd:
// Accumulate: col = COALESCE(col, 0) + EXCLUDED.col
updates[i] = fmt.Sprintf("%s=(COALESCE(%s.%s::numeric, 0) + EXCLUDED.%s::numeric)%s", col, o.table.nameEscaped, col, col, resultCast)
case UpdateOpMax:
// Maximum: col = GREATEST(COALESCE(col, 0), EXCLUDED.col)
updates[i] = fmt.Sprintf("%s=GREATEST(COALESCE(%s.%s::numeric, 0), EXCLUDED.%s::numeric)%s", col, o.table.nameEscaped, col, col, resultCast)
case UpdateOpMin:
// Minimum: col = LEAST(COALESCE(col, 0), EXCLUDED.col)
updates[i] = fmt.Sprintf("%s=LEAST(COALESCE(%s.%s::numeric, 0), EXCLUDED.%s::numeric)%s", col, o.table.nameEscaped, col, col, resultCast)
case UpdateOpSetIfNull:
// Set only if NULL (first value wins): col = COALESCE(col, EXCLUDED.col)
updates[i] = fmt.Sprintf("%s=COALESCE(%s.%s, EXCLUDED.%s)", col, o.table.nameEscaped, col, col)
default:
updates[i] = fmt.Sprintf("%s=EXCLUDED.%s", col, col)
}
}

// Escape primary key column names to preserve case sensitivity (e.g., camelCase)
Expand All @@ -441,9 +482,31 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (normalQ
return insertQuery, "", nil

case OperationTypeUpdate:
// Build per-field update expressions based on UpdateOp
updates := make([]string, len(columns))
for i := 0; i < len(columns); i++ {
updates[i] = fmt.Sprintf("%s=%s", columns[i], values[i])
for i := range columns {
col := columns[i]
val := values[i]
resultCast := getResultCast(scanTypes[i])
switch updateOps[i] {
case UpdateOpSet:
// Direct assignment: col = value
updates[i] = fmt.Sprintf("%s=%s", col, val)
case UpdateOpAdd:
// Accumulate: col = COALESCE(col, 0) + value
updates[i] = fmt.Sprintf("%s=(COALESCE(%s::numeric, 0) + %s::numeric)%s", col, col, val, resultCast)
case UpdateOpMax:
// Maximum: col = GREATEST(COALESCE(col, 0), value)
updates[i] = fmt.Sprintf("%s=GREATEST(COALESCE(%s::numeric, 0), %s::numeric)%s", col, col, val, resultCast)
case UpdateOpMin:
// Minimum: col = LEAST(COALESCE(col, 0), value)
updates[i] = fmt.Sprintf("%s=LEAST(COALESCE(%s::numeric, 0), %s::numeric)%s", col, col, val, resultCast)
case UpdateOpSetIfNull:
// Set only if NULL (first value wins): col = COALESCE(col, value)
updates[i] = fmt.Sprintf("%s=COALESCE(%s, %s)", col, col, val)
default:
updates[i] = fmt.Sprintf("%s=%s", col, val)
}
}

primaryKeySelector := getPrimaryKeyWhereClause(o.primaryKey, "")
Expand Down Expand Up @@ -475,13 +538,15 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (normalQ
}
}

func (d *PostgresDialect) prepareColValues(table *TableInfo, colValues map[string]string) (columns []string, values []string, err error) {
func (d *PostgresDialect) prepareColValues(table *TableInfo, colValues map[string]FieldData) (columns []string, values []string, updateOps []UpdateOp, scanTypes []reflect.Type, err error) {
if len(colValues) == 0 {
return
}

columns = make([]string, len(colValues))
values = make([]string, len(colValues))
updateOps = make([]UpdateOp, len(colValues))
scanTypes = make([]reflect.Type, len(colValues))

i := 0
for colName := range colValues {
Expand All @@ -491,19 +556,21 @@ func (d *PostgresDialect) prepareColValues(table *TableInfo, colValues map[strin
sort.Strings(columns) // sorted for determinism in tests

for i, columnName := range columns {
value := colValues[columnName]
fieldData := colValues[columnName]
columnInfo, found := table.columnsByName[columnName]
if !found {
return nil, nil, fmt.Errorf("cannot find column %q for table %q (valid columns are %q)", columnName, table.identifier, strings.Join(maps.Keys(table.columnsByName), ", "))
return nil, nil, nil, nil, fmt.Errorf("cannot find column %q for table %q (valid columns are %q)", columnName, table.identifier, strings.Join(maps.Keys(table.columnsByName), ", "))
}

normalizedValue, err := d.normalizeValueType(value, columnInfo.scanType)
normalizedValue, err := d.normalizeValueType(fieldData.Value, columnInfo.scanType)
if err != nil {
return nil, nil, fmt.Errorf("getting sql value from table %s for column %q raw value %q: %w", table.identifier, columnName, value, err)
return nil, nil, nil, nil, fmt.Errorf("getting sql value from table %s for column %q raw value %q: %w", table.identifier, columnName, fieldData.Value, err)
}

values[i] = normalizedValue
columns[i] = columnInfo.escapedName // escape the column name
updateOps[i] = fieldData.UpdateOp
scanTypes[i] = columnInfo.scanType
}
return
}
Expand Down
153 changes: 153 additions & 0 deletions db_changes/db/dialect_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,4 +266,157 @@ func TestRevertOp(t *testing.T) {
})
}

}

// TestPrepareStatement_UpdateOp tests SQL generation for each UpdateOp type
func TestPrepareStatement_UpdateOp(t *testing.T) {
// Create a test table with numeric column
table := createTestTable(t, "test_table", "id", "amount")

tests := []struct {
name string
opType OperationType
updateOp UpdateOp
value string
expectSQL string // substring to check in generated SQL
}{
// UPSERT with different UpdateOps
{
name: "UPSERT SET",
opType: OperationTypeUpsert,
updateOp: UpdateOpSet,
value: "100",
expectSQL: `"amount"=EXCLUDED."amount"`,
},
{
name: "UPSERT ADD",
opType: OperationTypeUpsert,
updateOp: UpdateOpAdd,
value: "100",
expectSQL: `"amount"=(COALESCE("test_table"."amount"::numeric, 0) + EXCLUDED."amount"::numeric)`,
},
{
name: "UPSERT MAX",
opType: OperationTypeUpsert,
updateOp: UpdateOpMax,
value: "100",
expectSQL: `"amount"=GREATEST(COALESCE("test_table"."amount"::numeric, 0), EXCLUDED."amount"::numeric)`,
},
{
name: "UPSERT MIN",
opType: OperationTypeUpsert,
updateOp: UpdateOpMin,
value: "100",
expectSQL: `"amount"=LEAST(COALESCE("test_table"."amount"::numeric, 0), EXCLUDED."amount"::numeric)`,
},
{
name: "UPSERT SET_IF_NULL",
opType: OperationTypeUpsert,
updateOp: UpdateOpSetIfNull,
value: "100",
expectSQL: `"amount"=COALESCE("test_table"."amount", EXCLUDED."amount")`,
},

// UPDATE with different UpdateOps
{
name: "UPDATE SET",
opType: OperationTypeUpdate,
updateOp: UpdateOpSet,
value: "100",
expectSQL: `"amount"=100`,
},
{
name: "UPDATE ADD",
opType: OperationTypeUpdate,
updateOp: UpdateOpAdd,
value: "100",
expectSQL: `"amount"=(COALESCE("amount"::numeric, 0) + 100::numeric)`,
},
{
name: "UPDATE MAX",
opType: OperationTypeUpdate,
updateOp: UpdateOpMax,
value: "100",
expectSQL: `"amount"=GREATEST(COALESCE("amount"::numeric, 0), 100::numeric)`,
},
{
name: "UPDATE MIN",
opType: OperationTypeUpdate,
updateOp: UpdateOpMin,
value: "100",
expectSQL: `"amount"=LEAST(COALESCE("amount"::numeric, 0), 100::numeric)`,
},
{
name: "UPDATE SET_IF_NULL",
opType: OperationTypeUpdate,
updateOp: UpdateOpSetIfNull,
value: "100",
expectSQL: `"amount"=COALESCE("amount", 100)`,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dialect := PostgresDialect{schemaName: "public"}
op := &Operation{
table: table,
opType: tt.opType,
primaryKey: map[string]string{"id": "123"},
data: map[string]FieldData{
"amount": {Value: tt.value, UpdateOp: tt.updateOp},
},
}

sql, _, err := dialect.prepareStatement("public", op)
require.NoError(t, err)
assert.Contains(t, sql, tt.expectSQL, "SQL should contain expected UpdateOp clause")
})
}
}

// TestPrepareStatement_INSERT_IgnoresUpdateOp tests that INSERT ignores UpdateOp (always direct values)
func TestPrepareStatement_INSERT_IgnoresUpdateOp(t *testing.T) {
table := createTestTable(t, "test_table", "id", "amount")

// For INSERT, UpdateOp should not affect the SQL - it's always a direct INSERT
ops := []UpdateOp{UpdateOpSet, UpdateOpAdd, UpdateOpMax, UpdateOpMin, UpdateOpSetIfNull}

for _, updateOp := range ops {
t.Run(updateOpName(updateOp), func(t *testing.T) {
dialect := PostgresDialect{schemaName: "public"}
op := &Operation{
table: table,
opType: OperationTypeInsert,
primaryKey: map[string]string{"id": "123"},
data: map[string]FieldData{
"amount": {Value: "100", UpdateOp: updateOp},
},
}

sql, _, err := dialect.prepareStatement("public", op)
require.NoError(t, err)
// INSERT should always be a simple INSERT regardless of UpdateOp
assert.Contains(t, sql, "INSERT INTO")
assert.Contains(t, sql, "VALUES")
assert.NotContains(t, sql, "ON CONFLICT", "INSERT should not have ON CONFLICT clause")
})
}
}

// createTestTable creates a TableInfo for testing with numeric columns
func createTestTable(t *testing.T, name, pkCol string, extraCols ...string) *TableInfo {
t.Helper()
columns := make(map[string]*ColumnInfo)

// Primary key column (text)
columns[pkCol] = NewColumnInfo(pkCol, "text", "")

// Extra columns (numeric for UpdateOp testing)
for _, col := range extraCols {
columns[col] = NewColumnInfo(col, "numeric", int64(0))
}

table, err := NewTableInfo("public", name, []string{pkCol}, columns)
require.NoError(t, err)
return table
}
Loading