Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 178 additions & 21 deletions src/internal/system/update_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,103 @@ import (
"os"
"path/filepath"
"strings"
"sync"

"github.com/linuxdeepin/go-lib/strv"
"github.com/linuxdeepin/go-lib/utils"
)

var (
tempSourceDirMu sync.RWMutex
tempSourceDirPath string
)

func SetTempSourceDir(tempDir string) {
tempSourceDirMu.Lock()
defer tempSourceDirMu.Unlock()
tempSourceDirPath = tempDir
logger.Infof("SetTempSourceDir: %s", tempDir)
}

func ClearTempSourceDir() {
tempSourceDirMu.Lock()
defer tempSourceDirMu.Unlock()
tempSourceDirPath = ""
logger.Info("ClearTempSourceDir")
}

func RefreshSymlinksForSourceDir(sourceDir string) {
tempSourceDirMu.RLock()
tempDir := tempSourceDirPath
tempSourceDirMu.RUnlock()

if tempDir == "" {
return
}

files, err := os.ReadDir(tempDir)
if err != nil {
logger.Warningf("RefreshSymlinksForSourceDir: failed to read tempDir %s: %v", tempDir, err)
return
}

sourceFiles, err := os.ReadDir(sourceDir)
if err != nil {
logger.Warningf("RefreshSymlinksForSourceDir: failed to read sourceDir %s: %v", sourceDir, err)
return
}

sourceFileMap := make(map[string]bool)
for _, f := range sourceFiles {
if strings.HasSuffix(f.Name(), ".list") {
sourceFileMap[f.Name()] = true
}
}

for _, f := range files {
linkPath := filepath.Join(tempDir, f.Name())
if utils.IsSymlink(linkPath) {
targetPath, err := os.Readlink(linkPath)
if err != nil {
logger.Warningf("RefreshSymlinksForSourceDir: failed to read link %s: %v", linkPath, err)
continue
}

if strings.HasPrefix(targetPath, sourceDir) {
fileName := filepath.Base(targetPath)
newTargetPath := filepath.Join(sourceDir, fileName)

if !utils.IsFileExist(newTargetPath) {
if sourceFileMap[fileName] {
os.Remove(linkPath)
if err := os.Symlink(newTargetPath, linkPath); err != nil {
logger.Warningf("RefreshSymlinksForSourceDir: failed to create symlink: %v", err)
} else {
logger.Infof("RefreshSymlinksForSourceDir: updated symlink %s -> %s", linkPath, newTargetPath)
}
}
}
}
}
}

for _, f := range sourceFiles {
fileName := f.Name()
if !strings.HasSuffix(fileName, ".list") {
continue
}
linkPath := filepath.Join(tempDir, fileName)
if _, err := os.Lstat(linkPath); os.IsNotExist(err) {
newTargetPath := filepath.Join(sourceDir, fileName)
if err := os.Symlink(newTargetPath, linkPath); err != nil {
logger.Warningf("RefreshSymlinksForSourceDir: failed to create symlink for %s: %v", fileName, err)
} else {
logger.Infof("RefreshSymlinksForSourceDir: created symlink %s -> %s", linkPath, newTargetPath)
}
}
}
}

type UpdateType uint64

// org.deepin.upgradedelivery的ServiceStatus返回的服务状态。1为可用,2为不可用,0为未知
Expand Down Expand Up @@ -173,25 +265,22 @@ func UpdateSystemDefaultSourceDir(sourceList []string) error {
return nil
}

func UpdateP2pDefaultSourceDir(updateType UpdateType, upgradeDeliveryEnabled bool) {
func UpdateP2pDefaultSourceDir(updateType UpdateType, upgradeDeliveryEnabled bool, platformRepos []string) error {
if !upgradeDeliveryEnabled {
return
return nil
}
var sourceDir string
switch updateType {
case SystemUpdate:
sourceDir = SoftLinkSystemSourceDir
case SecurityUpdate:
sourceDir = SecuritySourceDir
logger.Infof("UpdateP2pDefaultSourceDir: updateType=%v, platformRepos=%v", updateType, platformRepos)
sourceDir := GetCategorySourceMap()[updateType]
p2pSource, err := ioutil.TempFile("/tmp", "p2pSource-*.list")
if err != nil {
return fmt.Errorf("create temp file for p2p source failed: %v", err)
}
p2pSource, err := ioutil.TempFile("/tmp", "p2pSource.*.list")
defer os.Remove(p2pSource.Name())
//从SystemSource.d或SecuritySource.d中读取每个文件内容并将协议替换成delivery协议后存放到/tmp中
//这么做为了保证替换协议的原子性
files, err := ioutil.ReadDir(sourceDir)
if err != nil {
logger.Warningf("Error writing dir: %s err:%v", sourceDir, err)
return
return fmt.Errorf("Error writing dir: %s err: %w", sourceDir, err)
}
for _, file := range files {
var content []byte
Expand All @@ -208,39 +297,103 @@ func UpdateP2pDefaultSourceDir(updateType UpdateType, upgradeDeliveryEnabled boo
}
content, err = ioutil.ReadFile(targetPath)
if err != nil {
logger.Warningf("Error reading file: %v\n", err)
return
return fmt.Errorf("error reading file: %w", err)
}
} else {
content, err = ioutil.ReadFile(filePath)
if err != nil {
logger.Warningf("Error reading file: %v\n", err)
return
return fmt.Errorf("error reading file: %w", err)
}
}
var newContent string
newContent = strings.ReplaceAll(string(content), "https://", "delivery://")
newContent := replaceMatchedReposWithDelivery(string(content), platformRepos)
_, err = p2pSource.Write([]byte(newContent))
if err != nil {
logger.Warningf("Error writing file: %v\n", err)
return
return fmt.Errorf("Error writing file: %w", err)
}
}
//所有协议均正常替换后重新创建SystemSource.d或SecuritySource.d,再讲/tmp中的文件拷贝过去
err = os.RemoveAll(sourceDir)
if err != nil {
logger.Warning(err)
return
return fmt.Errorf("failed to remove %s: %w", sourceDir, err)
}
err = os.MkdirAll(sourceDir, 0755)
if err != nil {
logger.Warning(err)
return
return fmt.Errorf("failed to create %s: %w", sourceDir, err)
}
err = utils.MoveFile(p2pSource.Name(), filepath.Join(sourceDir, filepath.Base(p2pSource.Name())))
if err != nil {
logger.Warning(err)
}
RefreshSymlinksForSourceDir(sourceDir)
return nil
}

func replaceMatchedReposWithDelivery(content string, platformRepos []string) string {
matchedURLs := make(map[string]struct{})
for _, repo := range platformRepos {
urlPath := extractURLPathFromLine(repo)
if urlPath != "" {
matchedURLs[urlPath] = struct{}{}
}
}
if len(matchedURLs) == 0 {
return content
}

var lines []string
for _, line := range strings.Split(content, "\n") {
urlPath := extractURLPathFromLine(line)
if _, ok := matchedURLs[urlPath]; ok && urlPath != "" {
lines = append(lines, replaceRepoSchemeWithDelivery(line))
continue
}
lines = append(lines, line)
}
return strings.Join(lines, "\n")
}

func extractURLPathFromLine(line string) string {
fields := strings.Fields(line)
for _, field := range fields {
if strings.Contains(field, "://") {
return extractURLPath(field)
}
}
return ""
}

func extractURLPath(urlField string) string {
idx := strings.Index(urlField, "://")
if idx == -1 {
return ""
}
rest := urlField[idx+3:]
return strings.TrimSuffix(rest, "/")
}

func replaceRepoSchemeWithDelivery(line string) string {
fields := strings.Fields(line)
if len(fields) < 2 || fields[0] != "deb" {
return line
}
for i := 1; i < len(fields); i++ {
if strings.HasPrefix(fields[i], "[") {
continue
}
switch {
case strings.HasPrefix(fields[i], "https://"):
fields[i] = "delivery://" + strings.TrimSuffix(strings.TrimPrefix(fields[i], "https://"), "/")
return strings.Join(fields, " ")
case strings.HasPrefix(fields[i], "http://"):
fields[i] = "delivery://" + strings.TrimSuffix(strings.TrimPrefix(fields[i], "http://"), "/")
return strings.Join(fields, " ")
case strings.HasPrefix(fields[i], "delivery://"):
return line
}
}
return line
}

func UpdateSecurityDefaultSourceDir(sourceList []string) error {
Expand Down Expand Up @@ -406,12 +559,15 @@ func CustomSourceWrapper(updateType UpdateType, doRealAction func(path string, u
var beforeDoRealErr error
var sourceDir string
// #nosec G301
logger.Infof("sourcePathList: %v", sourcePathList)
sourceDir, beforeDoRealErr = os.MkdirTemp("/tmp", "*Source.d")
if beforeDoRealErr != nil {
logger.Warning(beforeDoRealErr)
return beforeDoRealErr
}
SetTempSourceDir(sourceDir)
unref := func() {
ClearTempSourceDir()
err := os.RemoveAll(sourceDir)
if err != nil {
logger.Warning(err)
Expand Down Expand Up @@ -449,6 +605,7 @@ func CustomSourceWrapper(updateType UpdateType, doRealAction func(path string, u
// 创建对应的软链接
for _, filePath := range allSourceFilePaths {
linkPath := filepath.Join(sourceDir, filepath.Base(filePath))
logger.Infof("filePath: %s --> linkPath: %s", filePath, linkPath)
beforeDoRealErr = os.Symlink(filePath, linkPath)
if beforeDoRealErr != nil {
return fmt.Errorf("create symlink for %q failed: %v", filePath, beforeDoRealErr)
Expand Down
114 changes: 114 additions & 0 deletions src/internal/system/update_type_delivery_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// SPDX-FileCopyrightText: 2026 UnionTech Software Technology Co., Ltd.
//
// SPDX-License-Identifier: GPL-3.0-or-later

package system

import (
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestReplaceMatchedReposWithDelivery(t *testing.T) {
localContent := strings.Join([]string{
"deb https://packages.example.com/desktop beige main",
"deb https://packages.example.com/custom beige main",
"# keep comments untouched",
"deb https://security.example.com beige-security main",
}, "\n")

platformRepos := []string{
"deb https://packages.example.com/desktop beige main",
}

got := replaceMatchedReposWithDelivery(localContent, platformRepos)
lines := strings.Split(got, "\n")

assert.Equal(t, "deb delivery://packages.example.com/desktop beige main", lines[0])
assert.Equal(t, "deb https://packages.example.com/custom beige main", lines[1])
assert.Equal(t, "# keep comments untouched", lines[2])
assert.Equal(t, "deb https://security.example.com beige-security main", lines[3])
}

func TestReplaceMatchedReposWithDeliveryRequiresExactMatch(t *testing.T) {
localContent := "deb https://packages.example.com/desktop beige main"
platformRepos := []string{
"deb http://packages.example.com/desktop beige main",
}

got := replaceMatchedReposWithDelivery(localContent, platformRepos)

assert.Equal(t, "deb delivery://packages.example.com/desktop beige main", got)
}

func TestReplaceMatchedReposWithDeliveryWithoutPlatformReposKeepsOriginal(t *testing.T) {
localContent := "deb https://packages.example.com/desktop beige main"

got := replaceMatchedReposWithDelivery(localContent, nil)

assert.Equal(t, localContent, got)
}

func TestReplaceMatchedReposWithDeliveryTrailingSlash(t *testing.T) {
testCases := []struct {
name string
localContent string
platformRepos []string
expectedResult string
}{
{
name: "platform has trailing slash, local does not",
localContent: "deb https://packages.example.com/desktop beige main",
platformRepos: []string{"deb https://packages.example.com/desktop/ beige main"},
expectedResult: "deb delivery://packages.example.com/desktop beige main",
},
{
name: "local has trailing slash, platform does not",
localContent: "deb https://packages.example.com/desktop/ beige main",
platformRepos: []string{"deb https://packages.example.com/desktop beige main"},
expectedResult: "deb delivery://packages.example.com/desktop beige main",
},
{
name: "both have trailing slash",
localContent: "deb https://packages.example.com/desktop/ beige main",
platformRepos: []string{"deb https://packages.example.com/desktop/ beige main"},
expectedResult: "deb delivery://packages.example.com/desktop beige main",
},
{
name: "neither has trailing slash",
localContent: "deb https://packages.example.com/desktop beige main",
platformRepos: []string{"deb https://packages.example.com/desktop beige main"},
expectedResult: "deb delivery://packages.example.com/desktop beige main",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := replaceMatchedReposWithDelivery(tc.localContent, tc.platformRepos)
assert.Equal(t, tc.expectedResult, got)
})
}
}

func TestExtractURLPath(t *testing.T) {
testCases := []struct {
input string
expected string
}{
{"https://packages.example.com/desktop", "packages.example.com/desktop"},
{"https://packages.example.com/desktop/", "packages.example.com/desktop"},
{"http://security.example.com/", "security.example.com"},
{"delivery://packages.example.com/desktop", "packages.example.com/desktop"},
{"no-url-here", ""},
{"", ""},
}

for _, tc := range testCases {
t.Run(tc.input, func(t *testing.T) {
got := extractURLPath(tc.input)
assert.Equal(t, tc.expected, got)
})
}
}
Loading
Loading