// Package mongo provides small helpers for connecting to MongoDB and for
// obtaining database and collection handles. Connection settings are read
// from environment variables at package initialisation.
package mongo

import (
	"context"
	"errors"
	"os"
	"strings"
	"time"

	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/mongo/readpref"
)

// Connection settings, populated from the environment by init.
var (
	mongo_uri      string // MONGO_URI; falls back to "mongodb://localhost:27017"
	mongo_database string // MONGO_DB; falls back to "admin"
	mongo_username string // MONGO_USERNAME; may be empty
	mongo_password string // MONGO_PASSWORD; may be empty
)

// init loads the connection settings from the environment, substituting
// defaults for the URI and the database name when they are unset or empty.
func init() {
	mongo_uri = os.Getenv("MONGO_URI")
	if mongo_uri == "" {
		mongo_uri = "mongodb://localhost:27017"
	}

	mongo_database = os.Getenv("MONGO_DB")
	if mongo_database == "" {
		mongo_database = "admin"
	}

	mongo_username = os.Getenv("MONGO_USERNAME")
	mongo_password = os.Getenv("MONGO_PASSWORD")
}

// newConnect creates a new client for mongo_uri using a 3-second connect
// timeout. SCRAM-SHA-256 credentials against the "admin" auth source are
// attached only when both mongo_username and mongo_password are non-empty;
// otherwise the client is configured from the URI alone.
func newConnect() (*mongo.Client, error) {
	clientOpts := options.Client().ApplyURI(mongo_uri).SetConnectTimeout(3 * time.Second)
	if mongo_username != "" && mongo_password != "" {
		clientOpts = clientOpts.SetAuth(options.Credential{
			AuthMechanism: "SCRAM-SHA-256",
			AuthSource:    "admin",
			Username:      mongo_username,
			Password:      mongo_password,
		})
	}

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	return mongo.Connect(ctx, clientOpts)
}

// newDB returns a handle for the database dbname, or for the configured
// default database (mongo_database) when dbname is empty. The returned error
// is currently always nil.
func newDB(client *mongo.Client, dbname string) (*mongo.Database, error) {
	db := dbname
	if db == "" {
		db = mongo_database
	}
	return client.Database(db), nil
}

// newCollection returns a handle for collection colname in database dbname.
// An empty dbname selects the configured default database.
func newCollection(client *mongo.Client, dbname, colname string) (*mongo.Collection, error) {
	db, err := newDB(client, dbname)
	if err != nil {
		return nil, err
	}
	return db.Collection(colname), nil
}

// newDefaultCollection returns a handle for collection colname in the
// configured default database.
func newDefaultCollection(client *mongo.Client, colname string) (*mongo.Collection, error) {
	// An empty database name makes newCollection fall back to mongo_database.
	return newCollection(client, "", colname)
}

// ping opens a fresh connection, pings the primary with a 3-second timeout
// and disconnects again. It returns the connect or ping error, if any.
func ping() error {
	client, err := newConnect()
	if err != nil {
		return err
	}

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	// Registered after cancel so that it runs before cancel. The disconnect
	// error is deliberately discarded: the ping result is what callers need.
	defer func() { _ = client.Disconnect(ctx) }()

	return client.Ping(ctx, readpref.Primary())
}

// Collection returns a handle for collection cname in the configured default
// database. The returned error is currently always nil.
func Collection(client *mongo.Client, cname string) (*mongo.Collection, error) {
	collection := client.Database(mongo_database).Collection(cname)
	return collection, nil
}

// CollectionMulti resolves a "database.collection" namespace to a collection
// handle, so that one client can address collections in any database.
//
// The namespace is split at the first dot only. MongoDB database names cannot
// contain dots while collection names can (for example "fs.files"), so
// "mydb.fs.files" selects collection "fs.files" in database "mydb".
//
// An error is returned when cname contains no dot, or when either the
// database part or the collection part is empty.
func CollectionMulti(client *mongo.Client, cname string) (*mongo.Collection, error) {
	name := strings.SplitN(cname, ".", 2)
	if len(name) != 2 {
		return nil, errors.New("collection名称不正确,请使用[database.collection]方式使用")
	}
	if name[0] == "" || name[1] == "" {
		return nil, errors.New("名称不能为空")
	}
	return client.Database(name[0]).Collection(name[1]), nil
}