Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,7 @@
#
# Time in duration format (e.g. 1h30m) after which a backend is considered busy
# LOCALAI_WATCHDOG_BUSY_TIMEOUT=5m

# allowed access ip config, ie: 192.168.1.0/24,10.0.0.1,127.0.0.1
# export LOCALAI_IP_ALLOWLIST="192.168.1.0/24,10.0.0.1,127.0.0.1"
# LOCALAI_IP_ALLOWLIST=192.168.1.0/24
2 changes: 2 additions & 0 deletions core/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type RunCMD struct {
ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" help:"Default context size for models" group:"performance"`

Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
IpAllowList string `env:"LOCALAI_IP_ALLOWLIST,IP_ALLOWLIST" help:"A list of IP addresses or CIDR ranges to allow access" group:"api"`
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
Expand Down Expand Up @@ -127,6 +128,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithP2PNetworkID(r.Peer2PeerNetworkID),
config.WithLoadToMemory(r.LoadToMemory),
config.WithMachineTag(r.MachineTag),
config.WithIPAllowList(r.IpAllowList),
}

if r.DisableMetricsEndpoint {
Expand Down
14 changes: 14 additions & 0 deletions core/config/application_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ type ApplicationConfig struct {
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration

MachineTag string

// ie: 192.168.1.0/24,10.0.0.1,127.0.0.1
IpAllowList string

IPAllowListHelper *utils.IPAllowList
}

type AppOption func(*ApplicationConfig)
Expand Down Expand Up @@ -128,6 +133,15 @@ func WithP2PToken(s string) AppOption {
}
}

func WithIPAllowList(s string) AppOption {
return func(o *ApplicationConfig) {
log.Info().Msgf("Application IpAllowList($LOCALAI_IP_ALLOWLIST): %s", s)
o.IpAllowList = s
var ipAllowListHelper, _ = utils.NewIPAllowList(s)
o.IPAllowListHelper = ipAllowListHelper
}
}

var EnableWatchDog = func(o *ApplicationConfig) {
o.WatchDog = true
}
Expand Down
100 changes: 100 additions & 0 deletions core/http/IPAllowList.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package http

import (
"fmt"
"net"
"net/netip"
"strings"
"sync"
)

type IPAllowList struct {
allowList string
cidrs []*net.IPNet
ips []net.IP
mu sync.RWMutex
enabled bool
}

func NewIPAllowList(allowList string) (*IPAllowList, error) {

w := &IPAllowList{}
err := w.Update(allowList)
return w, err
}

func (w *IPAllowList) GetAllowList() string {
return w.allowList
}

func (w *IPAllowList) Update(allowListStr string) error {
var cidrs []*net.IPNet
var ips []net.IP

allowList := make([]string, 0)
if allowListStr != "" {
allowList = strings.Split(allowListStr, ",")
}

for _, item := range allowList {
_, cidrNet, err := net.ParseCIDR(item)
if err == nil {
cidrs = append(cidrs, cidrNet)
} else {
ip := net.ParseIP(item)
if ip != nil {
ips = append(ips, ip)
} else {
return fmt.Errorf("invalid allowList item: %s", item)
}
}
}

w.mu.Lock()
defer w.mu.Unlock()
w.allowList = allowListStr
w.cidrs = cidrs
w.ips = ips
w.enabled = len(cidrs) > 0 || len(ips) > 0
return nil
}

func (w *IPAllowList) IsAllowed(ip interface{}) bool {
if !w.enabled {
return true
}

var parsedIP net.IP
switch v := ip.(type) {
case string:
parsedIP = net.ParseIP(v)
case net.IP:
parsedIP = v
case netip.Addr:
parsedIP = net.IP(v.AsSlice())
default:
if str, ok := v.(string); ok {
parsedIP = net.ParseIP(str)
}
}

if parsedIP == nil {
return false
}

w.mu.RLock()
defer w.mu.RUnlock()

for _, cidr := range w.cidrs {
if cidr.Contains(parsedIP) {
return true
}
}

for _, allowedIP := range w.ips {
if parsedIP.Equal(allowedIP) {
return true
}
}
return false
}
36 changes: 36 additions & 0 deletions core/http/IPAllowList_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package http_test

import (
. "github.com/mudler/LocalAI/core/http"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("IPAllowList", func() {
It("allows all IPs when allowlist is empty", func() {
w, err := NewIPAllowList("")
Expect(err).ToNot(HaveOccurred())
Expect(w.IsAllowed("192.168.1.100")).To(BeTrue())
})

It("respects CIDRs and explicit IPs", func() {
allowList := "192.168.1.0/24,10.0.0.1,127.0.0.1"
w, err := NewIPAllowList(allowList)
Expect(err).ToNot(HaveOccurred())

cases := []struct {
ip string
expected bool
}{
{"192.168.1.100", true},
{"10.0.0.1", true},
{"127.0.0.1", true},
{"10.0.0.2", false},
{"172.16.0.1", false},
}

for _, tc := range cases {
Expect(w.IsAllowed(tc.ip)).To(Equal(tc.expected), "IP: %s", tc.ip)
}
})
})
11 changes: 11 additions & 0 deletions core/http/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,17 @@ func API(application *application.Application) (*fiber.App, error) {
router.Use(recover.New())
}

//IP restriction
router.Use(func(c *fiber.Ctx) error {
clientIP := c.IP()
if application.ApplicationConfig().IPAllowListHelper.IsAllowed(clientIP) {
return c.Next()
}
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
"error": "Forbidden: your IP is not allowed",
})
})

if !application.ApplicationConfig().DisableMetrics {
metricsService, err := services.NewLocalAIMetricsService()
if err != nil {
Expand Down