main.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. package main
  2. import (
  3. "encoding/json"
  4. "flag"
  5. "fmt"
  6. "go-msa-auth/config"
  7. "go-msa-auth/internal/models"
  8. "go-msa-auth/server"
  9. "io/fs"
  10. "log"
  11. "os"
  12. "path/filepath"
  13. )
  14. func main() {
  15. // 检查子命令
  16. if len(os.Args) < 2 {
  17. fmt.Println("Expected 'server' or 'import' subcommands")
  18. os.Exit(1)
  19. }
  20. switch os.Args[1] {
  21. case "server":
  22. runServer()
  23. case "import":
  24. runImport()
  25. default:
  26. fmt.Println("Unknown subcommand:", os.Args[1])
  27. os.Exit(1)
  28. }
  29. }
  30. func runServer() {
  31. // 启动服务器逻辑
  32. config.InitConfig()
  33. if err := server.InitServer(); err != nil {
  34. log.Fatal("Failed to start server:", err)
  35. }
  36. }
  37. func runImport() {
  38. // 定义命令行参数
  39. importCmd := flag.NewFlagSet("import", flag.ExitOnError)
  40. dirPath := importCmd.String("dir", "./data/json_files", "Path to the directory containing JSON files")
  41. if err := importCmd.Parse(os.Args[2:]); err != nil {
  42. log.Fatal("Failed to parse flags:", err)
  43. }
  44. // 检查目录是否存在
  45. if _, err := os.Stat(*dirPath); os.IsNotExist(err) {
  46. log.Fatalf("The specified directory does not exist: %s", *dirPath)
  47. }
  48. // 初始化数据库
  49. models.InitDatabase("auth.db")
  50. // 遍历目录中的 JSON 文件
  51. err := filepath.WalkDir(*dirPath, func(path string, d fs.DirEntry, err error) error {
  52. if err != nil {
  53. return err
  54. }
  55. // 只处理 .json 文件
  56. if !d.IsDir() && filepath.Ext(path) == ".json" {
  57. log.Printf("Processing file: %s", path)
  58. if err := importJSONFile(path); err != nil {
  59. log.Printf("Failed to import file %s: %v", path, err)
  60. } else {
  61. log.Printf("Successfully imported file: %s", path)
  62. }
  63. }
  64. return nil
  65. })
  66. if err != nil {
  67. log.Fatalf("Error while traversing directory: %v", err)
  68. }
  69. log.Println("All files have been processed!")
  70. }
  71. // importJSONFile 处理单个 JSON 文件的导入逻辑
  72. func importJSONFile(filePath string) error {
  73. // 读取 JSON 文件
  74. data, err := os.ReadFile(filePath)
  75. if err != nil {
  76. return fmt.Errorf("failed to read file %s: %w", filePath, err)
  77. }
  78. // 解析 JSON 文件内容
  79. var importData models.LicenseImportData
  80. if err := json.Unmarshal(data, &importData); err != nil {
  81. return fmt.Errorf("failed to parse file %s: %w", filePath, err)
  82. }
  83. // 打印导入信息
  84. log.Printf("Importing licenses for app: %s, name: %s, expire_time: %s",
  85. importData.AppID, importData.Name, importData.ExpireTime)
  86. // 遍历授权码并写入数据库
  87. for _, license := range importData.Licenses {
  88. authCode := &models.AuthCode{
  89. UUID: license.UUID,
  90. DeviceID: license.LicenseFile, // 假设 LicenseFile 存储为 DeviceID
  91. }
  92. // 保存到数据库
  93. if err := models.SaveAuthCode(authCode); err != nil {
  94. log.Printf("Failed to save license with UUID %s: %v", license.UUID, err)
  95. } else {
  96. log.Printf("Successfully imported license: %s", license.UUID)
  97. }
  98. }
  99. return nil
  100. }