From f391935999dbfc47f8f2b813ad8794ce237bf272 Mon Sep 17 00:00:00 2001 From: zu1k Date: Tue, 24 May 2022 15:16:48 +0800 Subject: [PATCH] feat: Support update db by direct download Signed-off-by: zu1k --- internal/config/config.go | 1 + internal/db/default.go | 25 ++++++++++------ internal/db/type.go | 6 ++-- internal/db/update.go | 29 +++++++++++++++---- pkg/cdn/update.go | 18 ++++++------ pkg/dbif/db.go | 2 +- pkg/download/download.go | 25 ++++++++++++++++ .../{ip2locationdb.go => ip2location.go} | 12 ++++---- pkg/ip2region/update.go | 13 +++++---- 9 files changed, 93 insertions(+), 38 deletions(-) create mode 100644 pkg/download/download.go rename pkg/ip2location/{ip2locationdb.go => ip2location.go} (77%) diff --git a/internal/config/config.go b/internal/config/config.go index a60ee75..09c0d1a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -35,6 +35,7 @@ func ReadConfig(basePath string) { if err != nil { log.Fatalln("Config invalid:", err) } + db.NameDBMap.From(dbList) db.TypeDBMap.From(dbList) } diff --git a/internal/db/default.go b/internal/db/default.go index c15cfe0..74904e3 100644 --- a/internal/db/default.go +++ b/internal/db/default.go @@ -1,5 +1,10 @@ package db +import ( + "github.com/zu1k/nali/pkg/cdn" + "github.com/zu1k/nali/pkg/ip2region" +) + func GetDefaultDBList() List { return List{ &DB{ @@ -57,10 +62,11 @@ func GetDefaultDBList() List { NameAlias: []string{ "i2r", }, - Format: FormatIP2Region, - File: "ip2region.db", - Languages: LanguagesZH, - Types: TypesIPv4, + Format: FormatIP2Region, + File: "ip2region.db", + Languages: LanguagesZH, + Types: TypesIPv4, + DownloadUrls: ip2region.DownloadUrls, }, &DB{ Name: "ip2location", @@ -71,11 +77,12 @@ func GetDefaultDBList() List { }, &DB{ - Name: "cdn", - Format: FormatCDNSkkYml, - File: "cdn.yml", - Languages: LanguagesZH, - Types: TypesCDN, + Name: "cdn", + Format: FormatCDNSkkYml, + File: "cdn.yml", + Languages: LanguagesZH, + Types: TypesCDN, + DownloadUrls: cdn.DownloadUrls, }, } } diff --git a/internal/db/type.go b/internal/db/type.go index f1ea83c..8756494 100644 --- a/internal/db/type.go +++ b/internal/db/type.go @@ -15,12 +15,14 @@ import ( type DB struct { Name string - NameAlias []string `yaml:"name-alias,omitempty"` + NameAlias []string `yaml:"name-alias,omitempty" mapstructure:"name-alias"` Format Format File string Languages []string Types []Type + + DownloadUrls []string `yaml:"download-urls,omitempty" mapstructure:"download-urls"` } func (d *DB) get() (db dbif.DB) { @@ -43,7 +45,7 @@ func (d *DB) get() (db dbif.DB) { case FormatIP2Region: db, err = ip2region.NewIp2Region(filePath) case FormatIP2Location: - db, err = ip2locationdb.NewIP2LocationDB(filePath) + db, err = ip2location.NewIP2Location(filePath) case FormatCDNSkkYml: db, err = cdn.NewCDN(filePath) default: diff --git a/internal/db/update.go b/internal/db/update.go index 92bee06..9d31a29 100644 --- a/internal/db/update.go +++ b/internal/db/update.go @@ -6,6 +6,7 @@ import ( "time" "github.com/zu1k/nali/pkg/cdn" + "github.com/zu1k/nali/pkg/download" "github.com/zu1k/nali/pkg/ip2region" "github.com/zu1k/nali/pkg/qqwry" "github.com/zu1k/nali/pkg/zxipv6wry" @@ -38,13 +39,30 @@ var DbNameListForUpdate = []string{ func getUpdateFuncByName(name string) (func() error, string) { name = strings.TrimSpace(name) if db := getDbByName(name); db != nil { + // direct download if download-url not null + if len(db.DownloadUrls) > 0 { + return func() error { + log.Printf("正在下载最新 %s 数据库...\n", db.Name) + _, err := download.Download(db.File, db.DownloadUrls...) + if err != nil { + log.Printf("%s 数据库下载失败: %s\n", db.Name, db.File) + log.Println("error:", err) + return err + } else { + log.Printf("%s 数据库下载成功: %s\n", db.Name, db.File) + return nil + } + }, string(db.Format) + } + + // intenel download func switch db.Format { case FormatQQWry: return func() error { log.Println("正在下载最新 纯真 IPv4数据库...") _, err := qqwry.Download(getDbByName("qqwry").File) if err != nil { - log.Fatalln("数据库 QQWry 下载失败:", err) + log.Println("数据库 QQWry 下载失败:", err) } return err }, FormatQQWry @@ -53,7 +71,7 @@ func getUpdateFuncByName(name string) (func() error, string) { log.Println("正在下载最新 ZX IPv6数据库...") _, err := zxipv6wry.Download(getDbByName("zxipv6wry").File) if err != nil { - log.Fatalln("数据库 ZXIPv6Wry 下载失败:", err) + log.Println("数据库 ZXIPv6Wry 下载失败:", err) } return err }, FormatZXIPv6Wry @@ -62,7 +80,7 @@ func getUpdateFuncByName(name string) (func() error, string) { log.Println("正在下载最新 Ip2Region 数据库...") _, err := ip2region.Download(getDbByName("ip2region").File) if err != nil { - log.Fatalln("数据库 Ip2Region 下载失败:", err) + log.Println("数据库 Ip2Region 下载失败:", err) } return err }, FormatZXIPv6Wry @@ -71,13 +89,14 @@ func getUpdateFuncByName(name string) (func() error, string) { log.Println("正在下载最新 CDN服务提供商数据库...") _, err := cdn.Download(getDbByName("cdn").File) if err != nil { - log.Fatalln("数据库 CDN 下载失败:", err) + log.Println("数据库 CDN 下载失败:", err) } return err }, FormatZXIPv6Wry default: return func() error { - log.Fatalln("不支持该类型数据库的自动更新:", db.Format) + log.Println("暂不支持该类型数据库的自动更新") + log.Println("可通过指定数据库的 download-urls 从特定链接下载数据库文件") return nil }, time.Now().String() } diff --git a/pkg/cdn/update.go b/pkg/cdn/update.go index dbfd9db..3f25089 100644 --- a/pkg/cdn/update.go +++ b/pkg/cdn/update.go @@ -6,19 +6,19 @@ import ( "github.com/zu1k/nali/pkg/common" ) -const ( - githubUrl = "https://raw.githubusercontent.com/SukkaLab/cdn/master/src/cdn.yml" - jsdelivrUrl = "https://cdn.jsdelivr.net/gh/SukkaLab/cdn/src/cdn.yml" - - githubUrl2 = "https://raw.githubusercontent.com/4ft35t/cdn/master/src/cdn.yml" - jsdelivrUrl2 = "https://cdn.jsdelivr.net/gh/4ft35t/cdn/src/cdn.yml" -) +var DownloadUrls = []string{ + "https://cdn.jsdelivr.net/gh/SukkaLab/cdn/src/cdn.yml", + "https://raw.githubusercontent.com/SukkaLab/cdn/master/src/cdn.yml", + "https://cdn.jsdelivr.net/gh/4ft35t/cdn/src/cdn.yml", + "https://raw.githubusercontent.com/4ft35t/cdn/master/src/cdn.yml", +} +// Deprecated: This will be removed from 0.5.0, use package download instead func Download(filePath ...string) (data []byte, err error) { - data, err = common.GetHttpClient().Get(jsdelivrUrl, githubUrl, jsdelivrUrl2, githubUrl2) + data, err = common.GetHttpClient().Get(DownloadUrls...) if err != nil { log.Printf("CDN数据库下载失败,请手动下载解压后保存到本地: %s \n", filePath) - log.Println("下载链接:", githubUrl) + log.Println("下载链接:", DownloadUrls) return } diff --git a/pkg/dbif/db.go b/pkg/dbif/db.go index c6a7bdf..90c4bdd 100644 --- a/pkg/dbif/db.go +++ b/pkg/dbif/db.go @@ -30,6 +30,6 @@ var ( _ DB = &ipip.IPIPFree{} _ DB = &geoip.GeoIP{} _ DB = &ip2region.Ip2Region{} - _ DB = &ip2locationdb.IP2LocationDB{} + _ DB = &ip2location.IP2Location{} _ DB = &cdn.CDN{} ) diff --git a/pkg/download/download.go b/pkg/download/download.go new file mode 100644 index 0000000..7907716 --- /dev/null +++ b/pkg/download/download.go @@ -0,0 +1,25 @@ +package download + +import ( + "log" + + "github.com/zu1k/nali/pkg/common" +) + +func Download(filePath string, urls ...string) (data []byte, err error) { + _ = urls[0] + + data, err = common.GetHttpClient().Get(urls...) + if err != nil { + log.Printf("文件下载失败,请手动下载解压后保存到本地: %s \n", filePath) + log.Println("下载链接:", urls) + return + } + + if len(filePath) == 1 { + if err := common.SaveFile(filePath, data); err == nil { + log.Println("文件下载成功:", filePath) + } + } + return +} diff --git a/pkg/ip2location/ip2locationdb.go b/pkg/ip2location/ip2location.go similarity index 77% rename from pkg/ip2location/ip2locationdb.go rename to pkg/ip2location/ip2location.go index cd9f901..1e1cd1e 100644 --- a/pkg/ip2location/ip2locationdb.go +++ b/pkg/ip2location/ip2location.go @@ -1,4 +1,4 @@ -package ip2locationdb +package ip2location import ( "errors" @@ -10,13 +10,13 @@ import ( "github.com/ip2location/ip2location-go/v9" ) -// IP2LocationDB -type IP2LocationDB struct { +// IP2Location +type IP2Location struct { db *ip2location.DB } // new IP2Location from database file -func NewIP2LocationDB(filePath string) (*IP2LocationDB, error) { +func NewIP2Location(filePath string) (*IP2Location, error) { _, err := os.Stat(filePath) if err != nil && os.IsNotExist(err) { log.Println("文件不存在,请自行下载 IP2Location 库,并保存在", filePath) @@ -27,11 +27,11 @@ func NewIP2LocationDB(filePath string) (*IP2LocationDB, error) { if err != nil { log.Fatal(err) } - return &IP2LocationDB{db: db}, nil + return &IP2Location{db: db}, nil } } -func (x IP2LocationDB) Find(query string, params ...string) (result fmt.Stringer, err error) { +func (x IP2Location) Find(query string, params ...string) (result fmt.Stringer, err error) { ip := net.ParseIP(query) if ip == nil { return nil, errors.New("Query should be valid IP") diff --git a/pkg/ip2region/update.go b/pkg/ip2region/update.go index b5a5750..cde79ab 100644 --- a/pkg/ip2region/update.go +++ b/pkg/ip2region/update.go @@ -6,16 +6,17 @@ import ( "github.com/zu1k/nali/pkg/common" ) -const ( - githubUrl = "https://raw.githubusercontent.com/lionsoul2014/ip2region/master/data/ip2region.db" - jsdelivrUrl = "https://cdn.jsdelivr.net/gh/lionsoul2014/ip2region/data/ip2region.db" -) +var DownloadUrls = []string{ + "https://cdn.jsdelivr.net/gh/lionsoul2014/ip2region/data/ip2region.db", + "https://raw.githubusercontent.com/lionsoul2014/ip2region/master/data/ip2region.db", +} +// Deprecated: This will be removed from 0.5.0, use package download instead func Download(filePath ...string) (data []byte, err error) { - data, err = common.GetHttpClient().Get(jsdelivrUrl, githubUrl) + data, err = common.GetHttpClient().Get(DownloadUrls...) if err != nil { log.Printf("CDN数据库下载失败,请手动下载解压后保存到本地: %s \n", filePath) - log.Println("下载链接:", githubUrl) + log.Println("下载链接:", DownloadUrls) return }