diff --git a/internal/conf/const.go b/internal/conf/const.go index b99d8849c..3517c2851 100644 --- a/internal/conf/const.go +++ b/internal/conf/const.go @@ -191,4 +191,5 @@ const ( PathKey SharingIDKey SkipHookKey + VhostPrefixKey ) diff --git a/internal/db/sharing.go b/internal/db/sharing.go index 8670b15f3..89d80a849 100644 --- a/internal/db/sharing.go +++ b/internal/db/sharing.go @@ -14,6 +14,16 @@ func GetSharingById(id string) (*model.SharingDB, error) { return &s, nil } +// GetSharingByDomain 根据绑定的域名查询 sharing 记录(用于虚拟主机能力)。 +// 仅当 sharing.Domain 字段精确匹配时返回;调用方需自行判断 Disabled / Expires / Files 等有效性。 +func GetSharingByDomain(domain string) (*model.SharingDB, error) { + var s model.SharingDB + if err := db.Where("domain = ?", domain).First(&s).Error; err != nil { + return nil, errors.Wrapf(err, "failed get sharing by domain") + } + return &s, nil +} + func GetSharings(pageIndex, pageSize int) (sharings []model.SharingDB, count int64, err error) { sharingDB := db.Model(&model.SharingDB{}) if err := sharingDB.Count(&count).Error; err != nil { @@ -38,6 +48,13 @@ func GetSharingsByCreatorId(creator uint, pageIndex, pageSize int) (sharings []m } func CreateSharing(s *model.SharingDB) (string, error) { + // domain 非空时做唯一性提前校验 + if s.Domain != "" { + var exist model.SharingDB + if err := db.Where("domain = ?", s.Domain).First(&exist).Error; err == nil { + return "", errors.New("domain already used") + } + } if s.ID == "" { id := random.String(8) for len(id) < 12 { @@ -61,6 +78,13 @@ func CreateSharing(s *model.SharingDB) (string, error) { } func UpdateSharing(s *model.SharingDB) error { + // domain 非空时校验唯一性(排除自身) + if s.Domain != "" { + var exist model.SharingDB + if err := db.Where("domain = ? AND id <> ?", s.Domain, s.ID).First(&exist).Error; err == nil { + return errors.New("domain already used") + } + } return errors.WithStack(db.Save(s).Error) } diff --git a/internal/db/virtual_host.go b/internal/db/virtual_host.go new file mode 100644 index 000000000..10060e9d1 --- /dev/null +++ b/internal/db/virtual_host.go @@ -0,0 +1 @@ +package db \ No newline at end of file diff --git a/internal/model/sharing.go b/internal/model/sharing.go index c5dd95e9e..4283c7f86 100644 --- a/internal/model/sharing.go +++ b/internal/model/sharing.go @@ -14,6 +14,10 @@ type SharingDB struct { Remark string `json:"remark"` Readme string `json:"readme" gorm:"type:text"` Header string `json:"header" gorm:"type:text"` + // Domain 绑定的域名,可为空;非空时该条记录额外作为虚拟主机参与 Host 匹配(与旧 VirtualHost.Domain 等价)。 + Domain string `json:"domain" gorm:"uniqueIndex"` + // WebHosting 仅在 Domain 非空时有效;为 true 时启用 Web 托管模式(直接响应文件内容),为 false 时仅做路径重映射。 + WebHosting bool `json:"web_hosting"` Sort } diff --git a/internal/model/virtual_host.go b/internal/model/virtual_host.go new file mode 100644 index 000000000..0c3b4f2ff --- /dev/null +++ b/internal/model/virtual_host.go @@ -0,0 +1 @@ +package model \ No newline at end of file diff --git a/internal/op/sharing.go b/internal/op/sharing.go index 9db51c59d..c0678a751 100644 --- a/internal/op/sharing.go +++ b/internal/op/sharing.go @@ -4,6 +4,7 @@ import ( "fmt" stdpath "path" "strings" + "time" "github.com/OpenListTeam/OpenList/v4/internal/db" "github.com/OpenListTeam/OpenList/v4/internal/model" @@ -12,6 +13,7 @@ import ( "github.com/OpenListTeam/go-cache" "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "gorm.io/gorm" ) func makeJoined(sdb []model.SharingDB) []model.Sharing { @@ -42,6 +44,11 @@ func makeJoined(sdb []model.SharingDB) []model.Sharing { var sharingCache = cache.NewMemCache(cache.WithShards[*model.Sharing](8)) var sharingG singleflight.Group[*model.Sharing] +// domainSharingCache 按虚拟主机 domain 作为 key 缓存对应的 *model.Sharing。 +// 允许缓存为 nil 以实现"负缓存"防止穿透。 +var domainSharingCache = cache.NewMemCache(cache.WithShards[*model.Sharing](2)) +var domainSharingG singleflight.Group[*model.Sharing] + func GetSharingById(id string, refresh ...bool) (*model.Sharing, error) { if !utils.IsBool(refresh...) { if sharing, ok := sharingCache.Get(id); ok { @@ -71,6 +78,68 @@ func GetSharingById(id string, refresh ...bool) (*model.Sharing, error) { return sharing, err } +// GetSharingByDomain 根据 domain 获取可用的虚拟主机 sharing(带缓存)。 +// 仅当 sharing.Domain 非空、Disabled=false、Files 非空、Expires 未过期时才视为有效。 +// 如果在 DB 中未找到,会负缓存 5 分钟,避免反复穿透 DB。 +func GetSharingByDomain(domain string) (*model.Sharing, error) { + if domain == "" { + return nil, errors.New("empty domain") + } + if s, ok := domainSharingCache.Get(domain); ok { + if s == nil { + log.Debugf("[Sharing] domain cache hit (nil) for %q", domain) + return nil, errors.New("sharing not found by domain") + } + log.Debugf("[Sharing] domain cache hit for %q id=%s", domain, s.ID) + if !s.Valid() { + return nil, errors.New("sharing not valid") + } + return s, nil + } + sharing, err, _ := domainSharingG.Do(domain, func() (*model.Sharing, error) { + sdb, err := db.GetSharingByDomain(domain) + if err != nil { + if errors.Is(errors.Cause(err), gorm.ErrRecordNotFound) { + log.Debugf("[Sharing] domain=%q not found in db, caching nil", domain) + domainSharingCache.Set(domain, nil, cache.WithEx[*model.Sharing](time.Minute*5)) + return nil, errors.New("sharing not found by domain") + } + return nil, errors.WithMessagef(err, "failed get sharing by domain [%s]", domain) + } + creator, err := GetUserById(sdb.CreatorId) + if err != nil { + return nil, errors.WithMessagef(err, "failed get sharing creator [%s]", sdb.ID) + } + var files []string + if err = utils.Json.UnmarshalFromString(sdb.FilesRaw, &files); err != nil { + files = make([]string, 0) + } + s := &model.Sharing{ + SharingDB: sdb, + Files: files, + Creator: creator, + } + domainSharingCache.Set(domain, s, cache.WithEx[*model.Sharing](time.Hour)) + return s, nil + }) + if err != nil { + return nil, err + } + if sharing == nil || !sharing.Valid() { + return nil, errors.New("sharing not valid for domain") + } + return sharing, nil +} + +// invalidateDomainCache 在创建/更新/删除记录时调用,同时传入新/旧 domain 以使两者都失效。 +func invalidateDomainCache(domains ...string) { + for _, d := range domains { + if d != "" { + domainSharingCache.Del(d) + } + } +} + func GetSharings(pageIndex, pageSize int) ([]model.Sharing, int64, error) { s, cnt, err := db.GetSharings(pageIndex, pageSize) if err != nil { @@ -118,6 +187,7 @@ func CreateSharing(sharing *model.Sharing) (id string, err error) { if err != nil { return "", errors.WithStack(err) } + invalidateDomainCache(sharing.Domain) return db.CreateSharing(sharing.SharingDB) } @@ -129,12 +199,24 @@ func UpdateSharing(sharing *model.Sharing, skipMarshal ...bool) (err error) { return errors.WithStack(err) } } + // 读取旧记录以便同时失效旧 domain 缓存 + var oldDomain string + if old, e := db.GetSharingById(sharing.ID); e == nil { + oldDomain = old.Domain + } sharingCache.Del(sharing.ID) + invalidateDomainCache(oldDomain, sharing.Domain) return db.UpdateSharing(sharing.SharingDB) } func DeleteSharing(sid string) error { + // 先读取 domain 用于失效缓存 + var oldDomain string + if old, e := db.GetSharingById(sid); e == nil { + oldDomain = old.Domain + } sharingCache.Del(sid) + invalidateDomainCache(oldDomain) return db.DeleteSharingById(sid) } diff --git a/internal/op/virtual_host.go b/internal/op/virtual_host.go new file mode 100644 index 000000000..b9e21d5cf --- /dev/null +++ b/internal/op/virtual_host.go @@ -0,0 +1 @@ +package op \ No newline at end of file diff --git a/server/handles/fsread.go b/server/handles/fsread.go index a90fc1082..9468e92c0 100644 --- a/server/handles/fsread.go +++ b/server/handles/fsread.go @@ -2,6 +2,7 @@ package handles import ( "fmt" + "net" stdpath "path" "strings" "time" @@ -69,6 +70,8 @@ func FsListSplit(c *gin.Context) { SharingList(c, &req) return } + // 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径 + req.Path = applyVhostPathMapping(c, req.Path) user := c.Request.Context().Value(conf.UserKey).(*model.User) if user.IsGuest() && user.Disabled { common.ErrorStrResp(c, "Guest user is disabled, login please", 401) @@ -272,6 +275,11 @@ func FsGetSplit(c *gin.Context) { SharingGet(c, &req) return } + // 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径 + // 同时将 vhost.Path 前缀存入 context,供 FsGet 生成 /p/ 链接时去掉前缀 + var vhostPrefix string + req.Path, vhostPrefix = applyVhostPathMappingWithPrefix(c, req.Path) + common.GinWithValue(c, conf.VhostPrefixKey, vhostPrefix) user := c.Request.Context().Value(conf.UserKey).(*model.User) if user.IsGuest() && user.Disabled { common.ErrorStrResp(c, "Guest user is disabled, login please", 401) @@ -319,12 +327,14 @@ func FsGet(c *gin.Context, req *FsGetReq, user *model.User) { rawURL = common.GenerateDownProxyURL(storage.GetStorage(), reqPath) if rawURL == "" { query := "" + // 生成 /p/ 链接时,去掉 vhost 路径前缀,保持前端看到的路径一致 + downPath := stripVhostPrefix(c, reqPath) if isEncrypt(meta, reqPath) || setting.GetBool(conf.SignAll) { query = "?sign=" + sign.Sign(reqPath) } rawURL = fmt.Sprintf("%s/p%s%s", common.GetApiUrl(c), - utils.EncodePath(reqPath, true), + utils.EncodePath(downPath, true), query) } } else { @@ -427,3 +437,70 @@ func FsOther(c *gin.Context) { } common.SuccessResp(c, res) } + +// applyVhostPathMapping 根据请求的 Host 头匹配虚拟主机规则,将请求路径映射到实际路径。 +func applyVhostPathMapping(c *gin.Context, reqPath string) string { + mapped, _ := applyVhostPathMappingWithPrefix(c, reqPath) + return mapped +} + +// applyVhostPathMappingWithPrefix 根据请求的 Host 头匹配 sharing 中带 Domain 的虚拟主机记录, +// 将请求路径映射到 sharing.Files[0] 之下,同时返回该路径前缀(用于生成下载链接时去掉前缀)。 +// 例如:sharing.Files[0]="/123pan/Downloads",reqPath="/",则返回 ("/123pan/Downloads", "/123pan/Downloads") +// 例如:sharing.Files[0]="/123pan/Downloads",reqPath="/subdir",则返回 ("/123pan/Downloads/subdir", "/123pan/Downloads") +// 如果没有匹配的虚拟主机规则,则返回 (原始路径, "") +func applyVhostPathMappingWithPrefix(c *gin.Context, reqPath string) (string, string) { + rawHost := c.Request.Host + domain := stripHostPortForVhost(rawHost) + if domain == "" { + return reqPath, "" + } + sharing, err := op.GetSharingByDomain(domain) + if err != nil || sharing == nil { + return reqPath, "" + } + if sharing.WebHosting { + // Web 托管模式不做 API 路径重映射 + return reqPath, "" + } + if len(sharing.Files) == 0 { + return reqPath, "" + } + root := sharing.Files[0] + // Map request path into the sharing root and verify it does not escape via traversal. + // stdpath.Join calls Clean internally, which collapses ".." segments, so we only need + // to confirm the result still lives under root. + mapped := stdpath.Join(root, reqPath) + if !strings.HasPrefix(mapped, strings.TrimRight(root, "/")+"/") && mapped != root { + utils.Log.Warnf("[VirtualHost] path traversal rejected for API remapping: domain=%q reqPath=%q", domain, reqPath) + return reqPath, "" + } + utils.Log.Debugf("[VirtualHost] API path remapping: domain=%q reqPath=%q -> mappedPath=%q", domain, reqPath, mapped) + return mapped, root +} + +// stripVhostPrefix 从 gin context 中取出 vhost 路径前缀,并从 path 中去掉该前缀。 +// 用于生成 /p/ 下载链接时,将真实路径还原为前端看到的路径。 +func stripVhostPrefix(c *gin.Context, path string) string { + prefix, ok := c.Request.Context().Value(conf.VhostPrefixKey).(string) + if !ok || prefix == "" { + return path + } + if strings.HasPrefix(path, prefix+"/") { + return path[len(prefix):] + } + if path == prefix { + return "/" + } + return path +} + +// stripHostPortForVhost removes the port from a host string (supports IPv4, IPv6, and bracketed IPv6). +func stripHostPortForVhost(host string) string { + h, _, err := net.SplitHostPort(host) + if err != nil { + // No port present; return host as-is + return host + } + return h +} diff --git a/server/handles/sharing.go b/server/handles/sharing.go index 43f855afb..4c39d2cb6 100644 --- a/server/handles/sharing.go +++ b/server/handles/sharing.go @@ -412,6 +412,8 @@ type UpdateSharingReq struct { Remark string `json:"remark"` Readme string `json:"readme"` Header string `json:"header"` + Domain string `json:"domain"` + WebHosting bool `json:"web_hosting"` model.Sort CreatorName string `json:"creator"` Accessed int `json:"accessed"` @@ -470,6 +472,8 @@ func UpdateSharing(c *gin.Context) { s.Header = req.Header s.Readme = req.Readme s.Remark = req.Remark + s.Domain = req.Domain + s.WebHosting = req.WebHosting s.Creator = user if err = op.UpdateSharing(s); err != nil { common.ErrorResp(c, err, 500) @@ -528,6 +532,8 @@ func CreateSharing(c *gin.Context) { Remark: req.Remark, Readme: req.Readme, Header: req.Header, + Domain: req.Domain, + WebHosting: req.WebHosting, }, Files: req.Files, Creator: user, diff --git a/server/handles/virtual_host.go b/server/handles/virtual_host.go new file mode 100644 index 000000000..f9e30ff0b --- /dev/null +++ b/server/handles/virtual_host.go @@ -0,0 +1 @@ +package handles \ No newline at end of file diff --git a/server/middlewares/down.go b/server/middlewares/down.go index c1f81b54b..6fc3150d8 100644 --- a/server/middlewares/down.go +++ b/server/middlewares/down.go @@ -1,6 +1,8 @@ package middlewares import ( + "net" + stdpath "path" "strings" "github.com/OpenListTeam/OpenList/v4/internal/conf" @@ -17,10 +19,50 @@ import ( func PathParse(c *gin.Context) { rawPath := parsePath(c.Param("path")) + // 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径 + // 例如:vhost.Path="/123pan/Downloads",rawPath="/tests.html" -> "/123pan/Downloads/tests.html" + rawPath = applyDownVhostPathMapping(c, rawPath) common.GinWithValue(c, conf.PathKey, rawPath) c.Next() } +// applyDownVhostPathMapping 根据请求的 Host 头匹配 sharing 中带 Domain 的虚拟主机记录, +// 将下载/预览路由的路径映射到虚拟主机配置的实际路径(取 sharing.Files[0])。 +// 仅在 sharing 有效(未禁用、未过期、Files 非空)且非 Web 托管模式时生效。 +func applyDownVhostPathMapping(c *gin.Context, reqPath string) string { + rawHost := c.Request.Host + domain := stripDownHostPort(rawHost) + if domain == "" { + return reqPath + } + sharing, err := op.GetSharingByDomain(domain) + if err != nil || sharing == nil { + return reqPath + } + if sharing.WebHosting { + // Web 托管模式不做下载路径重映射 + return reqPath + } + if len(sharing.Files) == 0 { + return reqPath + } + root := sharing.Files[0] + // 路径重映射:将 reqPath 拼接到 root 后面 + mapped := stdpath.Join(root, reqPath) + utils.Log.Debugf("[VirtualHost] down path remapping: domain=%q reqPath=%q -> mappedPath=%q", domain, reqPath, mapped) + return mapped +} + +// stripDownHostPort removes the port from a host string (supports IPv4, IPv6, and bracketed IPv6). +func stripDownHostPort(host string) string { + h, _, err := net.SplitHostPort(host) + if err != nil { + // No port present; return host as-is + return host + } + return h +} + func Down(verifyFunc func(string, string) error) func(c *gin.Context) { return func(c *gin.Context) { rawPath := c.Request.Context().Value(conf.PathKey).(string) diff --git a/server/middlewares/virtual_host.go b/server/middlewares/virtual_host.go new file mode 100644 index 000000000..3372e400f --- /dev/null +++ b/server/middlewares/virtual_host.go @@ -0,0 +1,4 @@ +package middlewares + +// Note: Virtual host resolution is handled by existing handlers/middlewares. +// This file intentionally contains no additional code to avoid unused/dead middleware. diff --git a/server/static/static.go b/server/static/static.go index 29f97ff74..dbac8c288 100644 --- a/server/static/static.go +++ b/server/static/static.go @@ -5,16 +5,23 @@ import ( "errors" "fmt" "io" - "io/fs" + iofs "io/fs" + stdnet "net" "net/http" "os" + stdpath "path" "strings" "github.com/OpenListTeam/OpenList/v4/drivers/base" "github.com/OpenListTeam/OpenList/v4/internal/conf" + internalfs "github.com/OpenListTeam/OpenList/v4/internal/fs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" "github.com/OpenListTeam/OpenList/v4/internal/setting" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/public" + "github.com/OpenListTeam/OpenList/v4/server/common" "github.com/gin-gonic/gin" ) @@ -32,12 +39,12 @@ type Manifest struct { Icons []ManifestIcon `json:"icons"` } -var static fs.FS +var static iofs.FS func initStatic() { utils.Log.Debug("Initializing static file system...") if conf.Conf.DistDir == "" { - dist, err := fs.Sub(public.Public, "dist") + dist, err := iofs.Sub(public.Public, "dist") if err != nil { utils.Log.Fatalf("failed to read dist dir: %v", err) } @@ -76,7 +83,7 @@ func initIndex(siteConfig SiteConfig) { utils.Log.Debug("Reading index.html from static files system...") indexFile, err := static.Open("index.html") if err != nil { - if errors.Is(err, fs.ErrNotExist) { + if errors.Is(err, iofs.ErrNotExist) { utils.Log.Fatalf("index.html not exist, you may forget to put dist of frontend to public/dist") } utils.Log.Fatalf("failed to read index.html: %v", err) @@ -98,9 +105,9 @@ func initIndex(siteConfig SiteConfig) { manifestPath = siteConfig.BasePath + "/manifest.json" } replaceMap := map[string]string{ - "cdn: undefined": fmt.Sprintf("cdn: '%s'", siteConfig.Cdn), - "base_path: undefined": fmt.Sprintf("base_path: '%s'", siteConfig.BasePath), - `href="/manifest.json"`: fmt.Sprintf(`href="%s"`, manifestPath), + "cdn: undefined": fmt.Sprintf("cdn: '%s'", siteConfig.Cdn), + "base_path: undefined": fmt.Sprintf("base_path: '%s'", siteConfig.BasePath), + `href="/manifest.json"`: fmt.Sprintf(`href="%s"`, manifestPath), } conf.RawIndexHtml = replaceStrings(conf.RawIndexHtml, replaceMap) UpdateIndex() @@ -134,10 +141,10 @@ func UpdateIndex() { func ManifestJSON(c *gin.Context) { // Get site configuration to ensure consistent base path handling siteConfig := getSiteConfig() - + // Get site title from settings siteTitle := setting.GetStr(conf.SiteTitle) - + // Get logo from settings, use the first line (light theme logo) logoSetting := setting.GetStr(conf.Logo) logoUrl := strings.Split(logoSetting, "\n")[0] @@ -167,7 +174,7 @@ func ManifestJSON(c *gin.Context) { c.Header("Content-Type", "application/json") c.Header("Cache-Control", "public, max-age=3600") // cache for 1 hour - + if err := json.NewEncoder(c.Writer).Encode(manifest); err != nil { utils.Log.Errorf("Failed to encode manifest.json: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate manifest"}) @@ -181,7 +188,7 @@ func Static(r *gin.RouterGroup, noRoute func(handlers ...gin.HandlerFunc)) { initStatic() initIndex(siteConfig) folders := []string{"assets", "images", "streamer", "static"} - + if conf.Conf.Cdn == "" { utils.Log.Debug("Setting up static file serving...") r.Use(func(c *gin.Context) { @@ -192,7 +199,7 @@ func Static(r *gin.RouterGroup, noRoute func(handlers ...gin.HandlerFunc)) { } }) for _, folder := range folders { - sub, err := fs.Sub(static, folder) + sub, err := iofs.Sub(static, folder) if err != nil { utils.Log.Fatalf("can't find folder: %s", folder) } @@ -210,7 +217,49 @@ func Static(r *gin.RouterGroup, noRoute func(handlers ...gin.HandlerFunc)) { } utils.Log.Debug("Setting up catch-all route...") - noRoute(func(c *gin.Context) { + + // virtualHostHandler 处理虚拟主机 Web 托管,以及默认的前端 SPA 路由 + virtualHostHandler := func(c *gin.Context) { + // 直接从 Host 头解析域名,检查是否匹配 sharing 中的虚拟主机记录 + rawHost := c.Request.Host + domain := stripHostPort(rawHost) + utils.Log.Debugf("[VirtualHost] handler triggered: method=%s path=%s host=%q domain=%q", + c.Request.Method, c.Request.URL.Path, rawHost, domain) + if domain != "" { + sharing, err := op.GetSharingByDomain(domain) + if err != nil { + utils.Log.Debugf("[VirtualHost] domain=%q not matched any sharing: %v", domain, err) + } else if sharing != nil && len(sharing.Files) > 0 { + utils.Log.Debugf("[VirtualHost] domain=%q matched sharing: id=%s web_hosting=%v root=%q", + domain, sharing.ID, sharing.WebHosting, sharing.Files[0]) + if sharing.WebHosting { + // Web 托管模式:直接返回文件内容 + // 注入 guest 用户到 context,供 internalfs.Get/Link 权限检查使用 + guest, guestErr := op.GetGuest() + if guestErr != nil { + utils.Log.Errorf("[VirtualHost] failed to get guest user: %v", guestErr) + c.Status(http.StatusInternalServerError) + return + } + common.GinWithValue(c, conf.UserKey, guest) + if handleWebHosting(c, sharing) { + return + } + } else { + // 路径重映射模式(伪静态):直接返回正常的 SPA 页面 + // 地址栏保持不变,面包屑显示用户访问的路径 + // 实际的路径映射由后端 API(fs/list、fs/get)在处理请求时完成 + utils.Log.Debugf("[VirtualHost] path remapping mode: serving SPA for domain=%q path=%q", domain, c.Request.URL.Path) + c.Header("Content-Type", "text/html") + c.Status(200) + _, _ = c.Writer.WriteString(conf.IndexHtml) + c.Writer.Flush() + c.Writer.WriteHeaderNow() + return + } + } + } + if c.Request.Method != "GET" && c.Request.Method != "POST" { c.Status(405) return @@ -224,5 +273,179 @@ func Static(r *gin.RouterGroup, noRoute func(handlers ...gin.HandlerFunc)) { } c.Writer.Flush() c.Writer.WriteHeaderNow() + } + + // 显式注册根路径路由,确保 GET / 能被正确处理 + // gin 的 NoRoute 不会触发已注册路由前缀下的 GET / + r.GET("/", virtualHostHandler) + r.POST("/", virtualHostHandler) + // NoRoute 处理其他所有未匹配路径(如 /@manage、/d/... 等 SPA 路由) + noRoute(virtualHostHandler) +} + +// handleWebHosting 处理虚拟主机(sharing)的 Web 托管请求 +// 直接将文件内容返回给客户端,而不是走前端 SPA 路由 +// 返回 true 表示已处理,false 表示未处理(继续走默认逻辑) +func handleWebHosting(c *gin.Context, sharing *model.Sharing) bool { + if c.Request.Method != "GET" && c.Request.Method != "HEAD" { + utils.Log.Debugf("[VirtualHost] skip: method=%s not allowed for web hosting", c.Request.Method) + return false + } + if len(sharing.Files) == 0 { + utils.Log.Debugf("[VirtualHost] skip: sharing has no files") + return false + } + root := sharing.Files[0] + + reqPath := c.Request.URL.Path + // Map request path into the sharing root and verify it does not escape via traversal. + // stdpath.Join calls Clean internally, which collapses ".." segments, so we only need + // to confirm the result still lives under root. + filePath := stdpath.Join(root, reqPath) + if !strings.HasPrefix(filePath, strings.TrimRight(root, "/")+"/") && filePath != root { + utils.Log.Warnf("[VirtualHost] path traversal rejected: root=%q reqPath=%q", root, reqPath) + c.Status(http.StatusBadRequest) + return false + } + utils.Log.Debugf("[VirtualHost] handleWebHosting: reqPath=%q -> filePath=%q", reqPath, filePath) + + // 尝试获取文件 + obj, err := internalfs.Get(c.Request.Context(), filePath, &internalfs.GetArgs{NoLog: true}) + if err == nil && !obj.IsDir() { + // 找到文件,直接代理返回 + utils.Log.Debugf("[VirtualHost] serving file: %q", filePath) + serveWebHostingFile(c, filePath, obj.GetName()) + return true + } + utils.Log.Debugf("[VirtualHost] file not found or is dir at %q: %v", filePath, err) + + // 如果是目录或未找到,尝试 index.html + indexPath := stdpath.Join(filePath, "index.html") + obj, err = internalfs.Get(c.Request.Context(), indexPath, &internalfs.GetArgs{NoLog: true}) + if err == nil && !obj.IsDir() { + utils.Log.Debugf("[VirtualHost] serving index.html: %q", indexPath) + serveWebHostingFile(c, indexPath, "index.html") + return true + } + utils.Log.Debugf("[VirtualHost] index.html not found at %q: %v", indexPath, err) + + // 尝试 .html(SPA 友好路由) + if stdpath.Ext(reqPath) == "" && reqPath != "/" { + htmlPath := stdpath.Join(root, reqPath+".html") + obj, err = internalfs.Get(c.Request.Context(), htmlPath, &internalfs.GetArgs{NoLog: true}) + if err == nil && !obj.IsDir() { + utils.Log.Debugf("[VirtualHost] serving .html fallback: %q", htmlPath) + serveWebHostingFile(c, htmlPath, stdpath.Base(htmlPath)) + return true + } + utils.Log.Debugf("[VirtualHost] .html fallback not found at %q: %v", htmlPath, err) + } + + utils.Log.Debugf("[VirtualHost] no file matched for reqPath=%q, falling through", reqPath) + return false +} + +// serveWebHostingFile 通过代理方式直接返回文件内容 +func serveWebHostingFile(c *gin.Context, filePath, filename string) { + link, file, err := internalfs.Link(c.Request.Context(), filePath, model.LinkArgs{ + IP: c.ClientIP(), + Header: c.Request.Header, }) + if err != nil { + utils.Log.Errorf("web hosting: failed to get link for %s: %v", filePath, err) + c.Status(http.StatusInternalServerError) + return + } + defer link.Close() + + // 根据文件扩展名确定正确的 Content-Type + ext := strings.ToLower(stdpath.Ext(filename)) + contentType := mimeTypeByExt(ext) + + // 使用包装的 ResponseWriter,在 WriteHeader 时强制覆盖 Content-Type 和 Content-Disposition + // 这样即使 Proxy 内部的 maps.Copy 将上游响应头复制进来,我们也能在最终发送前覆盖 + wrapped := &forceContentTypeWriter{ + ResponseWriter: c.Writer, + contentType: contentType, + contentDisp: "inline", + } + + // 同时注入到 link.Header,供 attachHeader 路径(RangeReader/Concurrency 模式)使用 + if link.Header == nil { + link.Header = make(http.Header) + } + link.Header.Set("Content-Type", contentType) + link.Header.Set("Content-Disposition", "inline") + + // 使用通用代理函数处理文件传输 + if err := common.Proxy(wrapped, c.Request, link, file); err != nil { + utils.Log.Errorf("web hosting: proxy error for %s: %v", filePath, err) + } +} + +// forceContentTypeWriter 包装 http.ResponseWriter, +// 在 WriteHeader 时强制覆盖 Content-Type 和 Content-Disposition, +// 确保 HTML 等文件以正确类型返回而不是被浏览器下载 +type forceContentTypeWriter struct { + http.ResponseWriter + contentType string + contentDisp string +} + +func (w *forceContentTypeWriter) WriteHeader(statusCode int) { + w.ResponseWriter.Header().Set("Content-Type", w.contentType) + w.ResponseWriter.Header().Set("Content-Disposition", w.contentDisp) + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *forceContentTypeWriter) Write(b []byte) (int, error) { + return w.ResponseWriter.Write(b) +} + +// mimeTypeByExt 根据文件扩展名返回 MIME 类型 +func mimeTypeByExt(ext string) string { + switch ext { + case ".html", ".htm": + return "text/html; charset=utf-8" + case ".css": + return "text/css; charset=utf-8" + case ".js", ".mjs": + return "application/javascript; charset=utf-8" + case ".json": + return "application/json; charset=utf-8" + case ".xml": + return "application/xml; charset=utf-8" + case ".svg": + return "image/svg+xml" + case ".png": + return "image/png" + case ".jpg", ".jpeg": + return "image/jpeg" + case ".gif": + return "image/gif" + case ".webp": + return "image/webp" + case ".ico": + return "image/x-icon" + case ".woff": + return "font/woff" + case ".woff2": + return "font/woff2" + case ".ttf": + return "font/ttf" + case ".txt": + return "text/plain; charset=utf-8" + default: + return "application/octet-stream" + } +} + +// stripHostPort removes the port from a host string (supports IPv4, IPv6, and bracketed IPv6). +func stripHostPort(host string) string { + h, _, err := stdnet.SplitHostPort(host) + if err != nil { + // No port present; return host as-is + return host + } + return h }