Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/cluster/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (c *Cluster) RegisterPlugin(lifetime plugin_entities.PluginLifetime) error
// do plugin state update immediately
err = c.doPluginStateUpdate(l)
if err != nil {
return errors.Join(err, errors.New("failed to update plugin state"))
return errors.Join(err, errors.New("failed to update plugin state"))
}

if c.showLog {
Expand Down
29 changes: 25 additions & 4 deletions internal/core/plugin_manager/packages.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"os"
"strings"

"github.com/langgenius/dify-plugin-daemon/internal/db"
"github.com/langgenius/dify-plugin-daemon/internal/types/models"
Expand Down Expand Up @@ -56,17 +57,27 @@ func (p *PluginManager) SavePackage(
return nil, err
}

// create plugin if not exists
// create plugin if not exists (idempotent under concurrency)
if _, err := db.GetOne[models.PluginDeclaration](
db.Equal("plugin_unique_identifier", uniqueIdentifier.String()),
); err == db.ErrDatabaseNotFound {
err = db.Create(&models.PluginDeclaration{
createErr := db.Create(&models.PluginDeclaration{
PluginUniqueIdentifier: uniqueIdentifier.String(),
PluginID: uniqueIdentifier.PluginID(),
Declaration: declaration,
})
if err != nil {
return nil, err
if createErr != nil {
// ignore Postgres unique-violation (23505) errors triggered by concurrent inserts
if isUniqueViolation(createErr) {
return &declaration, nil
}
// fallback: if another goroutine has just inserted, read-after-write should succeed
if _, again := db.GetOne[models.PluginDeclaration](
db.Equal("plugin_unique_identifier", uniqueIdentifier.String()),
); again == nil {
return &declaration, nil
}
return nil, createErr
Comment thread
fatelei marked this conversation as resolved.
}
} else if err != nil {
return nil, err
Expand All @@ -75,6 +86,16 @@ func (p *PluginManager) SavePackage(
return &declaration, nil
}

// isUniqueViolation returns true if err indicates a PostgreSQL unique constraint violation (SQLSTATE 23505).
// Works across common drivers by matching canonical substrings to avoid hard dependency on driver types.
func isUniqueViolation(err error) bool {
if err == nil {
return false
}
s := err.Error()
return strings.Contains(s, "SQLSTATE 23505") || strings.Contains(s, "duplicate key value violates unique constraint")
Comment thread
fatelei marked this conversation as resolved.
}

func (p *PluginManager) GetPackage(
plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
) ([]byte, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/db/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ package db

// Note: The GetCache, UpdateCache, and DeleteCache functions that were previously
// in this file are deprecated and not used in the codebase.
// Direct cache operations should use the cache package (internal/utils/cache)
// Direct cache operations should use the cache package (internal/utils/cache)
12 changes: 6 additions & 6 deletions internal/service/unauthorized_langgenius_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func TestIsUnauthorizedLanggenius(t *testing.T) {
Author: tt.author,
},
}

got := isUnauthorizedLanggenius(declaration, tt.verification)
if got != tt.want {
t.Errorf("isUnauthorizedLanggenius() = %v, want %v", got, tt.want)
Expand All @@ -163,10 +163,10 @@ func TestIsUnauthorizedLanggenius_EdgeCases(t *testing.T) {
want: false, // spaces don't affect the comparison after lowercase
},
{
name: "langgenius with spaces but no verification",
author: " langgenius ",
name: "langgenius with spaces but no verification",
author: " langgenius ",
verification: nil,
want: false, // with spaces, not exact match after lowercase
want: false, // with spaces, not exact match after lowercase
},
{
name: "LaNgGeNiUs mixed case",
Expand All @@ -193,11 +193,11 @@ func TestIsUnauthorizedLanggenius_EdgeCases(t *testing.T) {
Author: tt.author,
},
}

got := isUnauthorizedLanggenius(declaration, tt.verification)
if got != tt.want {
t.Errorf("isUnauthorizedLanggenius() = %v, want %v for author=%q", got, tt.want, tt.author)
}
})
}
}
}
6 changes: 3 additions & 3 deletions internal/types/app/default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import (

func TestServerHostDefault(t *testing.T) {
tests := []struct {
name string
inputHost string
expectedHost string
name string
inputHost string
expectedHost string
}{
{
name: "empty host should default to 0.0.0.0",
Expand Down
23 changes: 19 additions & 4 deletions internal/types/models/curd/atomic.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,25 @@ func InstallPlugin(

err := db.Create(plugin, tx)
if err != nil {
return err
// Handle potential duplicate creation due to race: refetch and update refers
// to achieve idempotent behavior under concurrency.
p2, gerr := db.GetOne[models.Plugin](
db.WithTransactionContext(tx),
db.Equal("plugin_unique_identifier", pluginUniqueIdentifier.String()),
db.Equal("install_type", string(installType)),
db.WLock(),
)
if gerr != nil {
return err
}
p2.Refers++
if uerr := db.Update(&p2, tx); uerr != nil {
return uerr
}
pluginToBeReturns = &p2
} else {
pluginToBeReturns = plugin
}

pluginToBeReturns = plugin
} else if err != nil {
return err
} else {
Expand Down Expand Up @@ -604,7 +619,7 @@ func UpgradePlugin(
if err != nil {
return nil, err
}
pluginId := newPluginUniqueIdentifier.PluginID() // get the pluginId
pluginId := newPluginUniqueIdentifier.PluginID() // get the pluginId
pluginInstallationCacheKey := helper.PluginInstallationCacheKey(pluginId, tenantId) // make cache key
if _, err = cache.AutoDelete[models.PluginInstallation](pluginInstallationCacheKey); err != nil {
return nil, err
Expand Down
90 changes: 90 additions & 0 deletions internal/types/models/curd/install_concurrency_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package curd

import (
"strings"
"sync"
"testing"

"github.com/google/uuid"
"github.com/langgenius/dify-plugin-daemon/internal/db"
"github.com/langgenius/dify-plugin-daemon/internal/types/app"
"github.com/langgenius/dify-plugin-daemon/internal/types/models"
"github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities"
"github.com/stretchr/testify/require"
)

// TestInstallPlugin_IdempotentUnderConcurrency ensures creating the same plugin/installation
// concurrently is idempotent: only one plugin and one installation row are persisted.
func TestInstallPlugin_IdempotentUnderConcurrency(t *testing.T) {
cfg := &app.Config{
DBType: app.DB_TYPE_POSTGRESQL,
DBUsername: "postgres",
DBPassword: "difyai123456",
DBHost: "localhost",
DBPort: 5432,
DBDatabase: "dify_plugin_daemon",
DBSslMode: "disable",
}
cfg.SetDefault()
db.Init(cfg)
t.Cleanup(db.Close)

tenantID := uuid.NewString()
pluginName := "concurrency_demo_" + uuid.NewString()
checksum := uuid.NewString()
checksum = strings.ReplaceAll(checksum, "-", "")
// 32 hex chars
if len(checksum) > 32 {
checksum = checksum[:32]
}

identifier, err := plugin_entities.NewPluginUniqueIdentifier("tester/" + pluginName + ":1.0.0.0@" + checksum)
require.NoError(t, err)

const workers = 8
var wg sync.WaitGroup
wg.Add(workers)
errs := make(chan error, workers)
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
_, _, err := InstallPlugin(
tenantID,
identifier,
plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL,
&plugin_entities.PluginDeclaration{},
"unittest",
map[string]any{"from": "test"},
)
errs <- err
}()
}
wg.Wait()
close(errs)

// Validate DB state: exactly one plugin and one installation persisted
plugins, err := db.GetAll[models.Plugin](
db.Equal("plugin_unique_identifier", identifier.String()),
db.Equal("install_type", string(plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL)),
)
require.NoError(t, err)
require.Len(t, plugins, 1, "should persist exactly one plugin record")

installations, err := db.GetAll[models.PluginInstallation](
db.Equal("plugin_unique_identifier", identifier.String()),
db.Equal("tenant_id", tenantID),
)
require.NoError(t, err)
require.Len(t, installations, 1, "should persist exactly one installation record for tenant")

// A subsequent sequential install should be rejected as already installed
_, _, err = InstallPlugin(
tenantID,
identifier,
plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL,
&plugin_entities.PluginDeclaration{},
"unittest",
map[string]any{"from": "test"},
)
require.ErrorIs(t, err, ErrPluginAlreadyInstalled)
}
3 changes: 2 additions & 1 deletion internal/types/models/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import (
type Plugin struct {
Model
// PluginUniqueIdentifier is a unique identifier for the plugin, it contains version and checksum
PluginUniqueIdentifier string `json:"plugin_unique_identifier" gorm:"index;size:255"`
// Enforce uniqueness to guarantee idempotency under concurrency
PluginUniqueIdentifier string `json:"plugin_unique_identifier" gorm:"size:255;uniqueIndex:idx_plugin_unique_identifier"`
// PluginID is the id of the plugin, only plugin name is considered
PluginID string `json:"id" gorm:"index;size:255"`
Refers int `json:"refers" gorm:"default:0"`
Expand Down
2 changes: 1 addition & 1 deletion internal/types/models/trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ type TriggerInstallation struct {
Provider string `json:"provider" gorm:"column:provider;size:127;index;not null"`
PluginUniqueIdentifier string `json:"plugin_unique_identifier" gorm:"index;size:255"`
PluginID string `json:"plugin_id" gorm:"index;size:255"`
}
}
Loading