Skip to content
1 change: 1 addition & 0 deletions internal/conf/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,5 @@ const (
PathKey
SharingIDKey
SkipHookKey
VhostPrefixKey
)
24 changes: 24 additions & 0 deletions internal/db/sharing.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ func GetSharingById(id string) (*model.SharingDB, error) {
return &s, nil
}

// GetSharingByDomain 根据绑定的域名查询 sharing 记录(用于虚拟主机能力)。
// 仅当 sharing.Domain 字段精确匹配时返回;调用方需自行判断 Disabled / Expires / Files 等有效性。
func GetSharingByDomain(domain string) (*model.SharingDB, error) {
var s model.SharingDB
if err := db.Where("domain = ?", domain).First(&s).Error; err != nil {
return nil, errors.Wrapf(err, "failed get sharing by domain")
}
return &s, nil
}

func GetSharings(pageIndex, pageSize int) (sharings []model.SharingDB, count int64, err error) {
sharingDB := db.Model(&model.SharingDB{})
if err := sharingDB.Count(&count).Error; err != nil {
Expand All @@ -38,6 +48,13 @@ func GetSharingsByCreatorId(creator uint, pageIndex, pageSize int) (sharings []m
}

func CreateSharing(s *model.SharingDB) (string, error) {
// domain 非空时做唯一性提前校验
if s.Domain != "" {
var exist model.SharingDB
if err := db.Where("domain = ?", s.Domain).First(&exist).Error; err == nil {
return "", errors.New("domain already used")
}
}
if s.ID == "" {
id := random.String(8)
for len(id) < 12 {
Expand All @@ -61,6 +78,13 @@ func CreateSharing(s *model.SharingDB) (string, error) {
}

func UpdateSharing(s *model.SharingDB) error {
// domain 非空时校验唯一性(排除自身)
if s.Domain != "" {
var exist model.SharingDB
if err := db.Where("domain = ? AND id <> ?", s.Domain, s.ID).First(&exist).Error; err == nil {
return errors.New("domain already used")
}
}
return errors.WithStack(db.Save(s).Error)
}

Expand Down
1 change: 1 addition & 0 deletions internal/db/virtual_host.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package db
4 changes: 4 additions & 0 deletions internal/model/sharing.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ type SharingDB struct {
Remark string `json:"remark"`
Readme string `json:"readme" gorm:"type:text"`
Header string `json:"header" gorm:"type:text"`
// Domain 绑定的域名,可为空;非空时该条记录额外作为虚拟主机参与 Host 匹配(与旧 VirtualHost.Domain 等价)。
Domain string `json:"domain" gorm:"uniqueIndex"`
// WebHosting 仅在 Domain 非空时有效;为 true 时启用 Web 托管模式(直接响应文件内容),为 false 时仅做路径重映射。
WebHosting bool `json:"web_hosting"`
Sort
}

Expand Down
1 change: 1 addition & 0 deletions internal/model/virtual_host.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package model
82 changes: 82 additions & 0 deletions internal/op/sharing.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
stdpath "path"
"strings"
"time"

"github.com/OpenListTeam/OpenList/v4/internal/db"
"github.com/OpenListTeam/OpenList/v4/internal/model"
Expand All @@ -12,6 +13,7 @@ import (
"github.com/OpenListTeam/go-cache"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
)

func makeJoined(sdb []model.SharingDB) []model.Sharing {
Expand Down Expand Up @@ -42,6 +44,11 @@ func makeJoined(sdb []model.SharingDB) []model.Sharing {
var sharingCache = cache.NewMemCache(cache.WithShards[*model.Sharing](8))
var sharingG singleflight.Group[*model.Sharing]

// domainSharingCache 按虚拟主机 domain 作为 key 缓存对应的 *model.Sharing。
// 允许缓存为 nil 以实现"负缓存"防止穿透。
var domainSharingCache = cache.NewMemCache(cache.WithShards[*model.Sharing](2))
var domainSharingG singleflight.Group[*model.Sharing]

func GetSharingById(id string, refresh ...bool) (*model.Sharing, error) {
if !utils.IsBool(refresh...) {
if sharing, ok := sharingCache.Get(id); ok {
Expand Down Expand Up @@ -71,6 +78,68 @@ func GetSharingById(id string, refresh ...bool) (*model.Sharing, error) {
return sharing, err
}

// GetSharingByDomain 根据 domain 获取可用的虚拟主机 sharing(带缓存)。
// 仅当 sharing.Domain 非空、Disabled=false、Files 非空、Expires 未过期时才视为有效。
// 如果在 DB 中未找到,会负缓存 5 分钟,避免反复穿透 DB。
func GetSharingByDomain(domain string) (*model.Sharing, error) {
if domain == "" {
return nil, errors.New("empty domain")
}
if s, ok := domainSharingCache.Get(domain); ok {
if s == nil {
log.Debugf("[Sharing] domain cache hit (nil) for %q", domain)
return nil, errors.New("sharing not found by domain")
}
log.Debugf("[Sharing] domain cache hit for %q id=%s", domain, s.ID)
if !s.Valid() {
return nil, errors.New("sharing not valid")
}
return s, nil
}
sharing, err, _ := domainSharingG.Do(domain, func() (*model.Sharing, error) {
sdb, err := db.GetSharingByDomain(domain)
if err != nil {
if errors.Is(errors.Cause(err), gorm.ErrRecordNotFound) {
log.Debugf("[Sharing] domain=%q not found in db, caching nil", domain)
domainSharingCache.Set(domain, nil, cache.WithEx[*model.Sharing](time.Minute*5))
return nil, errors.New("sharing not found by domain")
}
return nil, errors.WithMessagef(err, "failed get sharing by domain [%s]", domain)
}
creator, err := GetUserById(sdb.CreatorId)
if err != nil {
return nil, errors.WithMessagef(err, "failed get sharing creator [%s]", sdb.ID)
}
var files []string
if err = utils.Json.UnmarshalFromString(sdb.FilesRaw, &files); err != nil {
files = make([]string, 0)
}
s := &model.Sharing{
SharingDB: sdb,
Files: files,
Creator: creator,
}
domainSharingCache.Set(domain, s, cache.WithEx[*model.Sharing](time.Hour))
return s, nil
})
if err != nil {
return nil, err
}
if sharing == nil || !sharing.Valid() {
return nil, errors.New("sharing not valid for domain")
}
return sharing, nil
}

// invalidateDomainCache 在创建/更新/删除记录时调用,同时传入新/旧 domain 以使两者都失效。
func invalidateDomainCache(domains ...string) {
for _, d := range domains {
if d != "" {
domainSharingCache.Del(d)
}
}
}

func GetSharings(pageIndex, pageSize int) ([]model.Sharing, int64, error) {
s, cnt, err := db.GetSharings(pageIndex, pageSize)
if err != nil {
Expand Down Expand Up @@ -118,6 +187,7 @@ func CreateSharing(sharing *model.Sharing) (id string, err error) {
if err != nil {
return "", errors.WithStack(err)
}
invalidateDomainCache(sharing.Domain)
return db.CreateSharing(sharing.SharingDB)
}

Expand All @@ -129,12 +199,24 @@ func UpdateSharing(sharing *model.Sharing, skipMarshal ...bool) (err error) {
return errors.WithStack(err)
}
}
// 读取旧记录以便同时失效旧 domain 缓存
var oldDomain string
if old, e := db.GetSharingById(sharing.ID); e == nil {
oldDomain = old.Domain
}
sharingCache.Del(sharing.ID)
invalidateDomainCache(oldDomain, sharing.Domain)
return db.UpdateSharing(sharing.SharingDB)
}

func DeleteSharing(sid string) error {
// 先读取 domain 用于失效缓存
var oldDomain string
if old, e := db.GetSharingById(sid); e == nil {
oldDomain = old.Domain
}
sharingCache.Del(sid)
invalidateDomainCache(oldDomain)
return db.DeleteSharingById(sid)
}

Expand Down
1 change: 1 addition & 0 deletions internal/op/virtual_host.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package op
79 changes: 78 additions & 1 deletion server/handles/fsread.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package handles

import (
"fmt"
"net"
stdpath "path"
"strings"
"time"
Expand Down Expand Up @@ -69,6 +70,8 @@ func FsListSplit(c *gin.Context) {
SharingList(c, &req)
return
}
// 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径
req.Path = applyVhostPathMapping(c, req.Path)
user := c.Request.Context().Value(conf.UserKey).(*model.User)
if user.IsGuest() && user.Disabled {
common.ErrorStrResp(c, "Guest user is disabled, login please", 401)
Expand Down Expand Up @@ -272,6 +275,11 @@ func FsGetSplit(c *gin.Context) {
SharingGet(c, &req)
return
}
// 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径
// 同时将 vhost.Path 前缀存入 context,供 FsGet 生成 /p/ 链接时去掉前缀
var vhostPrefix string
req.Path, vhostPrefix = applyVhostPathMappingWithPrefix(c, req.Path)
common.GinWithValue(c, conf.VhostPrefixKey, vhostPrefix)
user := c.Request.Context().Value(conf.UserKey).(*model.User)
if user.IsGuest() && user.Disabled {
common.ErrorStrResp(c, "Guest user is disabled, login please", 401)
Expand Down Expand Up @@ -319,12 +327,14 @@ func FsGet(c *gin.Context, req *FsGetReq, user *model.User) {
rawURL = common.GenerateDownProxyURL(storage.GetStorage(), reqPath)
if rawURL == "" {
query := ""
// 生成 /p/ 链接时,去掉 vhost 路径前缀,保持前端看到的路径一致
downPath := stripVhostPrefix(c, reqPath)
if isEncrypt(meta, reqPath) || setting.GetBool(conf.SignAll) {
query = "?sign=" + sign.Sign(reqPath)
}
rawURL = fmt.Sprintf("%s/p%s%s",
common.GetApiUrl(c),
utils.EncodePath(reqPath, true),
utils.EncodePath(downPath, true),
query)
}
} else {
Expand Down Expand Up @@ -427,3 +437,70 @@ func FsOther(c *gin.Context) {
}
common.SuccessResp(c, res)
}

// applyVhostPathMapping 根据请求的 Host 头匹配虚拟主机规则,将请求路径映射到实际路径。
func applyVhostPathMapping(c *gin.Context, reqPath string) string {
mapped, _ := applyVhostPathMappingWithPrefix(c, reqPath)
return mapped
}

// applyVhostPathMappingWithPrefix 根据请求的 Host 头匹配 sharing 中带 Domain 的虚拟主机记录,
// 将请求路径映射到 sharing.Files[0] 之下,同时返回该路径前缀(用于生成下载链接时去掉前缀)。
// 例如:sharing.Files[0]="/123pan/Downloads",reqPath="/",则返回 ("/123pan/Downloads", "/123pan/Downloads")
// 例如:sharing.Files[0]="/123pan/Downloads",reqPath="/subdir",则返回 ("/123pan/Downloads/subdir", "/123pan/Downloads")
// 如果没有匹配的虚拟主机规则,则返回 (原始路径, "")
func applyVhostPathMappingWithPrefix(c *gin.Context, reqPath string) (string, string) {
rawHost := c.Request.Host
domain := stripHostPortForVhost(rawHost)
if domain == "" {
return reqPath, ""
}
sharing, err := op.GetSharingByDomain(domain)
if err != nil || sharing == nil {
return reqPath, ""
}
if sharing.WebHosting {
// Web 托管模式不做 API 路径重映射
return reqPath, ""
}
if len(sharing.Files) == 0 {
return reqPath, ""
}
root := sharing.Files[0]
// Map request path into the sharing root and verify it does not escape via traversal.
// stdpath.Join calls Clean internally, which collapses ".." segments, so we only need
// to confirm the result still lives under root.
mapped := stdpath.Join(root, reqPath)
if !strings.HasPrefix(mapped, strings.TrimRight(root, "/")+"/") && mapped != root {
utils.Log.Warnf("[VirtualHost] path traversal rejected for API remapping: domain=%q reqPath=%q", domain, reqPath)
return reqPath, ""
}
utils.Log.Debugf("[VirtualHost] API path remapping: domain=%q reqPath=%q -> mappedPath=%q", domain, reqPath, mapped)
return mapped, root
}

Comment on lines +465 to +481
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

applyVhostPathMappingWithPrefix uses path.Join(vhost.Path, reqPath), which cleans the input and can remove .. segments before user.JoinPath runs. This can bypass utils.JoinBasePath's relative-path detection (it checks for .. in the original request path) and potentially allow traversal outside the user's base path. Preserve the original .. check (e.g. by using utils.JoinBasePath-style validation for reqPath) before joining, or return an error on relative paths.

Suggested change
}
// 路径重映射:将 reqPath 拼接到 vhost.Path 后面
mapped := stdpath.Join(vhost.Path, reqPath)
utils.Log.Debugf("[VirtualHost] API path remapping: domain=%q reqPath=%q -> mappedPath=%q", domain, reqPath, mapped)
return mapped, vhost.Path
}
}
// 安全检查:在进行路径拼接之前,确保请求路径中没有 ".." 段,避免通过 path.Join 清理后绕过上层的相对路径检测。
if !isSafeVhostReqPath(reqPath) {
utils.Log.Warnf("[VirtualHost] Suspicious path detected, skip vhost remapping: domain=%q reqPath=%q", domain, reqPath)
// 返回原始路径并不携带 vhost 前缀,让后续的基础路径校验逻辑基于原始路径继续处理。
return reqPath, ""
}
// 路径重映射:将 reqPath 拼接到 vhost.Path 后面
mapped := stdpath.Join(vhost.Path, reqPath)
utils.Log.Debugf("[VirtualHost] API path remapping: domain=%q reqPath=%q -> mappedPath=%q", domain, reqPath, mapped)
return mapped, vhost.Path
}
// isSafeVhostReqPath 检查虚拟主机请求路径中是否包含 ".." 段,以防止目录遍历。
// 仅当某个路径段恰好为 ".." 时视为不安全,例如 "/a/../b"。
func isSafeVhostReqPath(p string) bool {
if p == "" {
return true
}
for _, seg := range strings.Split(p, "/") {
if seg == ".." {
return false
}
}
return true
}

Copilot uses AI. Check for mistakes.
// stripVhostPrefix 从 gin context 中取出 vhost 路径前缀,并从 path 中去掉该前缀。
// 用于生成 /p/ 下载链接时,将真实路径还原为前端看到的路径。
func stripVhostPrefix(c *gin.Context, path string) string {
prefix, ok := c.Request.Context().Value(conf.VhostPrefixKey).(string)
if !ok || prefix == "" {
return path
}
if strings.HasPrefix(path, prefix+"/") {
return path[len(prefix):]
}
if path == prefix {
return "/"
}
return path
}

// stripHostPortForVhost removes the port from a host string (supports IPv4, IPv6, and bracketed IPv6).
func stripHostPortForVhost(host string) string {
h, _, err := net.SplitHostPort(host)
if err != nil {
// No port present; return host as-is
return host
}
return h
Comment on lines +498 to +505
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are now three near-identical helpers for stripping the port from Host (stripHostPort, stripDownHostPort, stripHostPortForVhost). This duplication makes it easy for them to drift (e.g. future normalization tweaks). Consider centralizing this logic in a shared helper (e.g. under server/common or pkg/utils) and reusing it across static/handles/middlewares.

Copilot uses AI. Check for mistakes.
}
6 changes: 6 additions & 0 deletions server/handles/sharing.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ type UpdateSharingReq struct {
Remark string `json:"remark"`
Readme string `json:"readme"`
Header string `json:"header"`
Domain string `json:"domain"`
WebHosting bool `json:"web_hosting"`
model.Sort
CreatorName string `json:"creator"`
Accessed int `json:"accessed"`
Expand Down Expand Up @@ -470,6 +472,8 @@ func UpdateSharing(c *gin.Context) {
s.Header = req.Header
s.Readme = req.Readme
s.Remark = req.Remark
s.Domain = req.Domain
s.WebHosting = req.WebHosting
s.Creator = user
if err = op.UpdateSharing(s); err != nil {
common.ErrorResp(c, err, 500)
Expand Down Expand Up @@ -528,6 +532,8 @@ func CreateSharing(c *gin.Context) {
Remark: req.Remark,
Readme: req.Readme,
Header: req.Header,
Domain: req.Domain,
WebHosting: req.WebHosting,
},
Files: req.Files,
Creator: user,
Expand Down
1 change: 1 addition & 0 deletions server/handles/virtual_host.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package handles
Loading
Loading