In my recent work I needed MongoDB transactions, so I consulted some references and wrapped them in a small component that provides a basic CRUD Repository base class and the UnitOfWork pattern. Today, let's briefly introduce this component.
About MongoDB transactions
MongoDB added multi-document transactions for replica sets in version 4.0 and extended them to sharded clusters in 4.2, about four years ago. Although we may not use MongoDB to replace traditional relational databases such as MySQL or SQL Server in our projects, there is no denying that its transaction support has become increasingly mature.
In MongoDB, "transaction" mainly means a multi-document transaction (operations on a single document are already atomic), and its usage is similar to that of traditional relational databases. Note, however, that multi-document transactions can only be used against a replica set or a sharded cluster accessed through mongos. If you only have a single standalone mongod instance, you cannot use multi-document transactions.
Side note: if you are interested in MongoDB, you may want to read my blog series "MongoDB Getting Started to Practical Learning Journey".
So, how do we actually perform transactional operations?
Transactions in Mongo Shell
The following demonstrates how to commit a multi-document transaction in the Mongo Shell:
var session = db.getMongo().startSession();
session.startTransaction({
    readConcern: { level: 'majority' },
    writeConcern: { w: 'majority' }
});
try {
    var coll1 = session.getDatabase('students').getCollection('teams');
    coll1.updateOne({ name: 'yzw-football-team' }, { $set: { members: 20 } });
    var coll2 = session.getDatabase('students').getCollection('records');
    coll2.updateOne({ name: 'Edison' }, { $set: { gender: 'Female' } });
    // Every operation succeeded: commit the transaction
    session.commitTransaction();
} catch (error) {
    // Something failed: roll back the transaction
    session.abortTransaction();
    throw error;
} finally {
    session.endSession();
}
Transactions in .NET applications
The following example runs a transaction through the MongoDB .NET Driver:
using (var clientSession = mongoClient.StartSession())
{
    // A transaction must be started explicitly before operations can join it
    clientSession.StartTransaction();
    try
    {
        var contacts = clientSession.Client.GetDatabase("testDB").GetCollection<Contact>("contacts");
        // Pass the session to every operation so that it runs inside the transaction
        contacts.ReplaceOne(clientSession, c => c.Id == "1234455", contact);
        var books = clientSession.Client.GetDatabase("testDB").GetCollection<Book>("books");
        books.DeleteOne(clientSession, book => book.Id == "1234455");
        clientSession.CommitTransaction();
    }
    catch (Exception ex)
    {
        // TODO: log ex
        clientSession.AbortTransaction();
    }
}
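The manual start/commit/abort sequence above does not retry. MongoDB labels some failures (for example write conflicts or a primary election during the transaction) as `TransientTransactionError`, and the usual advice is to rerun the whole transaction when that label appears; recent versions of the .NET driver wrap this logic in `IClientSessionHandle.WithTransaction`/`WithTransactionAsync`. The sketch below shows only the retry loop, without MongoDB.Driver. `ITransactionSession`, `TransientTransactionException`, `TransactionRunner` and `RecordingSession` are hypothetical stand-ins, with the exception playing the role of a driver error carrying that label.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for IClientSessionHandle so the sketch runs without MongoDB.Driver.
public interface ITransactionSession
{
    void StartTransaction();
    void CommitTransaction();
    void AbortTransaction();
}

// Hypothetical stand-in for a driver exception labeled "TransientTransactionError".
public class TransientTransactionException : Exception
{
    public TransientTransactionException(string message) : base(message) { }
}

public static class TransactionRunner
{
    // Runs `work` in a transaction: commit on success, abort on failure, and rerun the
    // whole transaction (up to maxAttempts) when the failure is transient.
    // Returns the number of attempts that were needed.
    public static int Run(ITransactionSession session, Action<ITransactionSession> work, int maxAttempts = 3)
    {
        for (var attempt = 1; ; attempt++)
        {
            session.StartTransaction();
            try
            {
                work(session);
                session.CommitTransaction();
                return attempt;
            }
            catch (TransientTransactionException) when (attempt < maxAttempts)
            {
                session.AbortTransaction(); // discard partial writes, then loop and retry
            }
            catch
            {
                session.AbortTransaction(); // non-transient or out of attempts: undo and rethrow
                throw;
            }
        }
    }
}

// In-memory fake that records the lifecycle calls it receives.
public class RecordingSession : ITransactionSession
{
    public List<string> Calls { get; } = new List<string>();
    public void StartTransaction() => Calls.Add("start");
    public void CommitTransaction() => Calls.Add("commit");
    public void AbortTransaction() => Calls.Add("abort");
}
```

Keeping the retry decision in one place means the business code inside `work` stays unaware of retries; it only has to be safe to run more than once.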
In most real-world applications, we usually perform CRUD through the Repository pattern and use the Unit of Work (UnitOfWork) pattern to coordinate several repositories and commit their changes in one transaction. So how do we implement this in our own projects?
After consulting some references, I implemented a basic component myself and named it EDT.MongoProxy. Let's see how it is implemented.
Singleton MongoClient
Following MongoDB's best practices, MongoClient should be registered as a singleton: in MongoDB.Driver, MongoClient is designed to be thread-safe and can be shared by multiple threads. A single shared instance also avoids the overhead of repeatedly creating MongoClient objects, which can hurt performance under heavy load.
To do this, we design a MongoDbConnection class that wraps MongoClient and register it in the IoC container as a singleton.
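In ASP.NET Core the registration could look like `services.AddSingleton<IMongoDbConnection, MongoDbConnection>()`, with `MongoDatabaseConfigs` registered as well so the constructor can be satisfied. The property being relied on is "created exactly once, even under concurrent first access", which can be sketched with the BCL's `Lazy<T>`. `ExpensiveClient` and `ClientHolder` are hypothetical names, and `ExpensiveClient` only stands in for `MongoClient` so the sketch runs without the driver.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical stand-in for MongoClient: counts how many times it is constructed.
public sealed class ExpensiveClient
{
    private static int _instances;
    public static int Instances => Volatile.Read(ref _instances);
    public ExpensiveClient() => Interlocked.Increment(ref _instances);
}

// Thread-safe lazy singleton: the first caller creates the client, later callers reuse it.
public static class ClientHolder
{
    private static readonly Lazy<ExpensiveClient> _client =
        new Lazy<ExpensiveClient>(() => new ExpensiveClient(), LazyThreadSafetyMode.ExecutionAndPublication);

    public static ExpensiveClient Client => _client.Value;
}
```

A DI container's singleton lifetime gives the same guarantee; `Lazy<T>` is simply the smallest self-contained way to show it.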
using Microsoft.Extensions.Configuration;
using MongoDB.Driver;

public class MongoDbConnection : IMongoDbConnection
{
    public IMongoClient DatabaseClient { get; }
    public string DatabaseName { get; }

    public MongoDbConnection(MongoDatabaseConfigs configs, IConfiguration configuration)
    {
        DatabaseClient = new MongoClient(configs.GetMongoClientSettings(configuration));
        DatabaseName = configs.DatabaseName;
    }
}
The MongoDatabaseConfigs class reads the configuration from appsettings and builds the corresponding MongoClientSettings.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using Microsoft.Extensions.Configuration;
using MongoDB.Driver;

/**
Config Example
"MongoDatabaseConfigs": {
    "Servers": "xxx01.edisontalk.net,xxx02.edisontalk.net,xxx03.edisontalk.net",
    "Port": 27017,
    "ReplicaSetName": "edt-replica",
    "DatabaseName": "EDT_Practices",
    "AuthDatabaseName": "admin",
    "ApplicationName": "Todo",
    "UserName": "service_testdev",
    "Password": "xxxxxxxxxxxxxxxxxxxxxxxx",
    "UseTLS": true,
    "AllowInsecureTLS": true,
    "SslCertificatePath": "/etc/pki/tls/certs/EDT_CA.cer",
    "UseEncryption": true
}
**/
public class MongoDatabaseConfigs
{
    private const string DEFAULT_AUTH_DB = "admin"; // Default AuthDB: admin

    public string Servers { get; set; }
    public int Port { get; set; } = 27017; // Default Port: 27017
    public string ReplicaSetName { get; set; }
    public string DatabaseName { get; set; }
    public string DefaultCollectionName { get; set; } = string.Empty;
    public string ApplicationName { get; set; }
    public string UserName { get; set; }
    public string Password { get; set; }
    public string AuthDatabaseName { get; set; } = DEFAULT_AUTH_DB; // Default AuthDB: admin
    public string CustomProperties { get; set; } = string.Empty;
    public bool UseTLS { get; set; } = false;
    public bool AllowInsecureTLS { get; set; } = true;
    public string SslCertificatePath { get; set; } = string.Empty;
    public bool StoreCertificateInKeyStore { get; set; } = false;

    public MongoClientSettings GetMongoClientSettings(IConfiguration configuration = null)
    {
        if (string.IsNullOrWhiteSpace(Servers))
            throw new ArgumentNullException(nameof(Servers), "Mongo Servers Configuration is Missing!");
        if (string.IsNullOrWhiteSpace(UserName) || string.IsNullOrWhiteSpace(Password))
            throw new ArgumentNullException(nameof(UserName), "Mongo Account Configuration is Missing!");

        // Base Configuration
        var settings = new MongoClientSettings
        {
            ApplicationName = ApplicationName,
            ReplicaSetName = ReplicaSetName
        };

        // Credential
        settings.Credential = MongoCredential.CreateCredential(AuthDatabaseName, UserName, Password);

        // Servers
        var mongoServers = Servers.Split(",", StringSplitOptions.RemoveEmptyEntries).ToList();
        if (mongoServers.Count == 1) // Standalone
        {
            settings.Server = new MongoServerAddress(mongoServers.First(), Port);
            settings.DirectConnection = true;
        }
        if (mongoServers.Count > 1) // Cluster
        {
            settings.Servers = mongoServers.Select(server => new MongoServerAddress(server, Port)).ToList();
            settings.DirectConnection = false;
        }

        // SSL
        if (UseTLS)
        {
            settings.UseTls = true;
            settings.AllowInsecureTls = AllowInsecureTLS;
            if (string.IsNullOrWhiteSpace(SslCertificatePath))
                throw new ArgumentNullException(nameof(SslCertificatePath), "SslCertificatePath is Missing!");

            if (StoreCertificateInKeyStore)
            {
                // Import the certificate into the current user's Root store
                var localTrustStore = new X509Store(StoreName.Root);
                var certificateCollection = new X509Certificate2Collection();
                certificateCollection.Import(SslCertificatePath);
                try
                {
                    localTrustStore.Open(OpenFlags.ReadWrite);
                    localTrustStore.AddRange(certificateCollection);
                }
                finally
                {
                    localTrustStore.Close();
                }
            }

            settings.SslSettings = new SslSettings
            {
                ClientCertificates = new List<X509Certificate> { new X509Certificate2(SslCertificatePath) },
                EnabledSslProtocols = SslProtocols.Tls13
            };
        }

        return settings;
    }
}
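The Standalone/Cluster branching in `GetMongoClientSettings` (one host means a direct connection, several hosts mean replica-set discovery) can be checked in isolation. The sketch below is BCL-only: hosts become plain `host:port` strings rather than `MongoServerAddress` objects, and `ServerPlan`/`ServerListParser` are hypothetical names. Unlike the class above, it also trims whitespace around each host, so a value like `"a, b"` is accepted.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Outcome of interpreting the "Servers" setting.
public sealed record ServerPlan(IReadOnlyList<string> Endpoints, bool DirectConnection);

public static class ServerListParser
{
    // Splits a comma-separated host list, drops empty entries, and attaches the shared port.
    public static ServerPlan Parse(string servers, int port = 27017)
    {
        var hosts = (servers ?? string.Empty)
            .Split(',')
            .Select(h => h.Trim())
            .Where(h => h.Length > 0)
            .ToList();
        if (hosts.Count == 0)
            throw new ArgumentException("Mongo Servers Configuration is Missing!", nameof(servers));

        var endpoints = hosts.Select(h => $"{h}:{port}").ToList();
        // A single host is treated as standalone (direct connection); several hosts as a cluster.
        return new ServerPlan(endpoints, DirectConnection: endpoints.Count == 1);
    }
}
```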
Core part: MongoDbContext
Here we follow the design of EF Core's DbContext to build a MongoDbContext: it obtains the singleton MongoClient from the IoC container, encapsulates starting and committing transactions, and simplifies application code.
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using MongoDB.Driver;

public class MongoDbContext : IMongoDbContext
{
    private readonly IMongoDatabase _database;
    private readonly IMongoClient _mongoClient;
    // Pending write operations, replayed in order inside one transaction on commit
    private readonly IList<Func<IClientSessionHandle, Task>> _commands = new List<Func<IClientSessionHandle, Task>>();

    public MongoDbContext(IMongoDbConnection dbClient)
    {
        _mongoClient = dbClient.DatabaseClient;
        _database = _mongoClient.GetDatabase(dbClient.DatabaseName);
    }

    public void AddCommand(Func<IClientSessionHandle, Task> func)
    {
        _commands.Add(func);
    }

    public async Task AddCommandAsync(Func<IClientSessionHandle, Task> func)
    {
        _commands.Add(func);
        await Task.CompletedTask;
    }

    /// <summary>
    /// NOTES: Only works in Cluster mode (replica set or sharded cluster)
    /// </summary>
    public int Commit(IClientSessionHandle session)
    {
        try
        {
            session.StartTransaction();
            foreach (var command in _commands)
            {
                // Wait for each command; otherwise the commit could run before the writes finish
                command(session).GetAwaiter().GetResult();
            }
            session.CommitTransaction();
            return _commands.Count;
        }
        catch (Exception)
        {
            if (session.IsInTransaction)
                session.AbortTransaction();
            return 0;
        }
        finally
        {
            _commands.Clear();
        }
    }

    /// <summary>
    /// NOTES: Only works in Cluster mode (replica set or sharded cluster)
    /// </summary>
    public async Task<int> CommitAsync(IClientSessionHandle session)
    {
        try
        {
            session.StartTransaction();
            foreach (var command in _commands)
            {
                await command(session);
            }
            await session.CommitTransactionAsync();
            return _commands.Count;
        }
        catch (Exception)
        {
            if (session.IsInTransaction)
                await session.AbortTransactionAsync();
            return 0;
        }
        finally
        {
            _commands.Clear();
        }
    }

    public IClientSessionHandle StartSession()
    {
        return _mongoClient.StartSession();
    }

    public async Task<IClientSessionHandle> StartSessionAsync()
    {
        return await _mongoClient.StartSessionAsync();
    }

    public IMongoCollection<T> GetCollection<T>(string name)
    {
        return _database.GetCollection<T>(name);
    }

    public void Dispose()
    {
        GC.SuppressFinalize(this);
    }
}
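The core idea of MongoDbContext is deferred execution: repositories enqueue delegates, and nothing touches the database until `Commit`/`CommitAsync` starts a transaction and replays them with the session. The driver-free sketch below reproduces that control flow; `FakeSession` and `CommandQueueContext` are hypothetical stand-ins. It also makes one requirement visible: each queued delegate receives the session, and in real code it must pass that session on to the driver call (the `IMongoCollection<T>` write methods have overloads taking an `IClientSessionHandle` as the first argument), otherwise the write runs outside the transaction.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

// Hypothetical stand-in for IClientSessionHandle: records what happens during a commit.
public sealed class FakeSession
{
    public List<string> Log { get; } = new List<string>();
}

// Writes are queued as delegates and only executed, in order, when CommitAsync runs.
public sealed class CommandQueueContext
{
    private readonly List<Func<FakeSession, Task>> _commands = new List<Func<FakeSession, Task>>();

    public void AddCommand(Func<FakeSession, Task> command) => _commands.Add(command);

    public async Task<int> CommitAsync(FakeSession session)
    {
        session.Log.Add("start");
        try
        {
            foreach (var command in _commands)
                await command(session); // each command writes through the session it is given
            session.Log.Add("commit");
            return _commands.Count;
        }
        catch
        {
            session.Log.Add("abort");
            return 0; // same contract as MongoDbContext: 0 means nothing was committed
        }
        finally
        {
            _commands.Clear(); // the queue is single-use, whatever the outcome
        }
    }
}
```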
Repository: MongoRepositoryBase
In real projects, we want a base RepositoryBase class that encapsulates all the CRUD methods. Each concrete repository then only has to inherit from RepositoryBase instead of rewriting the CRUD methods. That is the role of MongoRepositoryBase:
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations.Schema; // TableAttribute
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;
using MongoDB.Bson;
using MongoDB.Driver;

public class MongoRepositoryBase<TEntity> : IMongoRepositoryBase<TEntity> where TEntity : MongoEntityBase, new()
{
    protected readonly IMongoDbContext _dbContext;
    protected readonly IMongoCollection<TEntity> _dbSet;
    private readonly string _collectionName;
    private const string _keyField = "_id";

    public MongoRepositoryBase(IMongoDbContext mongoDbContext)
    {
        _dbContext = mongoDbContext;
        _collectionName = typeof(TEntity).GetAttributeValue((TableAttribute m) => m.Name) ?? typeof(TEntity).Name;
        if (string.IsNullOrWhiteSpace(_collectionName))
            throw new ArgumentNullException("Mongo collection name can't be NULL! Please set the attribute Table in your entity class.");
        _dbSet = mongoDbContext.GetCollection<TEntity>(_collectionName);
    }

    // When a session is supplied, the write is only queued in the context. It runs later,
    // with that same session (i.e. inside the transaction), when the unit of work commits.

    #region Create Part
    public async Task AddAsync(TEntity entity, IClientSessionHandle session = null)
    {
        if (session == null)
            await _dbSet.InsertOneAsync(entity);
        else
            await _dbContext.AddCommandAsync(s => _dbSet.InsertOneAsync(s, entity));
    }

    public async Task AddManyAsync(IEnumerable<TEntity> entityList, IClientSessionHandle session = null)
    {
        if (session == null)
            await _dbSet.InsertManyAsync(entityList);
        else
            await _dbContext.AddCommandAsync(s => _dbSet.InsertManyAsync(s, entityList));
    }
    #endregion

    #region Delete Part
    public async Task DeleteAsync(string id, IClientSessionHandle session = null)
    {
        var filter = Builders<TEntity>.Filter.Eq(_keyField, new ObjectId(id));
        if (session == null)
            await _dbSet.DeleteOneAsync(filter);
        else
            await _dbContext.AddCommandAsync(s => _dbSet.DeleteOneAsync(s, filter));
    }

    public async Task DeleteAsync(Expression<Func<TEntity, bool>> expression, IClientSessionHandle session = null)
    {
        if (session == null)
            await _dbSet.DeleteOneAsync(expression);
        else
            await _dbContext.AddCommandAsync(s => _dbSet.DeleteOneAsync(s, expression));
    }

    public async Task<DeleteResult> DeleteManyAsync(FilterDefinition<TEntity> filter, IClientSessionHandle session = null)
    {
        if (session == null)
            return await _dbSet.DeleteManyAsync(filter);
        await _dbContext.AddCommandAsync(s => _dbSet.DeleteManyAsync(s, filter));
        return new DeleteResult.Acknowledged(10); // placeholder: the real count is only known after commit
    }

    public async Task<DeleteResult> DeleteManyAsync(Expression<Func<TEntity, bool>> expression, IClientSessionHandle session = null)
    {
        if (session == null)
            return await _dbSet.DeleteManyAsync(expression);
        await _dbContext.AddCommandAsync(s => _dbSet.DeleteManyAsync(s, expression));
        return new DeleteResult.Acknowledged(10); // placeholder: the real count is only known after commit
    }
    #endregion

    #region Update Part
    public async Task UpdateAsync(TEntity entity, IClientSessionHandle session = null)
    {
        if (session == null)
            await _dbSet.ReplaceOneAsync(item => item.Id == entity.Id, entity);
        else
            await _dbContext.AddCommandAsync(s => _dbSet.ReplaceOneAsync(s, item => item.Id == entity.Id, entity));
    }

    public async Task UpdateAsync(Expression<Func<TEntity, bool>> expression, Expression<Action<TEntity>> entity, IClientSessionHandle session = null)
    {
        // Turn `x => new TEntity { A = ..., B = ... }` into one $set per assigned property
        var fieldList = new List<UpdateDefinition<TEntity>>();
        if (entity.Body is MemberInitExpression param)
        {
            foreach (var item in param.Bindings)
            {
                if (item is not MemberAssignment memberAssignment)
                    continue;
                var propertyName = item.Member.Name;
                var propertyValue = memberAssignment.Expression is ConstantExpression constantExpression
                    ? constantExpression.Value
                    : Expression.Lambda(memberAssignment.Expression).Compile().DynamicInvoke();
                // Never $set the key: the C# property Id is mapped to _id
                if (propertyName != _keyField && propertyName != nameof(MongoEntityBase.Id))
                    fieldList.Add(Builders<TEntity>.Update.Set(propertyName, propertyValue));
            }
        }
        var update = Builders<TEntity>.Update.Combine(fieldList);
        if (session == null)
            await _dbSet.UpdateOneAsync(expression, update);
        else
            await _dbContext.AddCommandAsync(s => _dbSet.UpdateOneAsync(s, expression, update));
    }

    public async Task UpdateAsync(FilterDefinition<TEntity> filter, UpdateDefinition<TEntity> update, IClientSessionHandle session = null)
    {
        if (session == null)
            await _dbSet.UpdateOneAsync(filter, update);
        else
            await _dbContext.AddCommandAsync(s => _dbSet.UpdateOneAsync(s, filter, update));
    }

    public async Task UpdateManyAsync(Expression<Func<TEntity, bool>> expression, UpdateDefinition<TEntity> update, IClientSessionHandle session = null)
    {
        if (session == null)
            await _dbSet.UpdateManyAsync(expression, update);
        else
            await _dbContext.AddCommandAsync(s => _dbSet.UpdateManyAsync(s, expression, update));
    }

    public async Task<UpdateResult> UpdateManayAsync(Dictionary<string, string> dic, FilterDefinition<TEntity> filter, IClientSessionHandle session = null)
    {
        // Fields to be updated: properties of TEntity whose names appear in dic
        var list = new List<UpdateDefinition<TEntity>>();
        foreach (var item in typeof(TEntity).GetProperties())
        {
            if (!dic.TryGetValue(item.Name, out var value))
                continue;
            list.Add(Builders<TEntity>.Update.Set(item.Name, value));
        }
        var update = Builders<TEntity>.Update.Combine(list);
        if (session == null)
            return await _dbSet.UpdateManyAsync(filter, update);
        await _dbContext.AddCommandAsync(s => _dbSet.UpdateManyAsync(s, filter, update));
        return new UpdateResult.Acknowledged(10, 10, null); // placeholder: real counts are only known after commit
    }
    #endregion

    #region Read Part
    public async Task<TEntity> GetAsync(Expression<Func<TEntity, bool>> expression, bool readFromPrimary = true)
    {
        var readPreference = GetReadPreference(readFromPrimary);
        return await _dbSet.WithReadPreference(readPreference).Find(expression).FirstOrDefaultAsync();
    }

    public async Task<TEntity> GetAsync(string id, bool readFromPrimary = true)
    {
        var readPreference = GetReadPreference(readFromPrimary);
        var filter = Builders<TEntity>.Filter.Eq(_keyField, new ObjectId(id));
        return await _dbSet.WithReadPreference(readPreference).Find(filter).FirstOrDefaultAsync();
    }

    public async Task<IEnumerable<TEntity>> GetAllAsync(bool readFromPrimary = true)
    {
        var readPreference = GetReadPreference(readFromPrimary);
        return await _dbSet.WithReadPreference(readPreference).Find(Builders<TEntity>.Filter.Empty).ToListAsync();
    }

    public async Task<long> CountAsync(Expression<Func<TEntity, bool>> expression, bool readFromPrimary = true)
    {
        var readPreference = GetReadPreference(readFromPrimary);
        return await _dbSet.WithReadPreference(readPreference).CountDocumentsAsync(expression);
    }

    public async Task<long> CountAsync(FilterDefinition<TEntity> filter, bool readFromPrimary = true)
    {
        var readPreference = GetReadPreference(readFromPrimary);
        return await _dbSet.WithReadPreference(readPreference).CountDocumentsAsync(filter);
    }

    public async Task<bool> ExistsAsync(Expression<Func<TEntity, bool>> predicate, bool readFromPrimary = true)
    {
        var readPreference = GetReadPreference(readFromPrimary);
        return await Task.FromResult(_dbSet.WithReadPreference(readPreference).AsQueryable().Any(predicate));
    }

    public async Task<List<TEntity>> FindListAsync(FilterDefinition<TEntity> filter, string[]? field = null, SortDefinition<TEntity>? sort = null, bool readFromPrimary = true)
    {
        return await BuildFind(filter, field, sort, readFromPrimary).ToListAsync();
    }

    public async Task<List<TEntity>> FindListByPageAsync(FilterDefinition<TEntity> filter, int pageIndex, int pageSize, string[]? field = null, SortDefinition<TEntity>? sort = null, bool readFromPrimary = true)
    {
        // pageIndex is 1-based
        return await BuildFind(filter, field, sort, readFromPrimary)
            .Skip((pageIndex - 1) * pageSize)
            .Limit(pageSize)
            .ToListAsync();
    }
    #endregion

    #region Protected Methods
    protected ReadPreference GetReadPreference(bool readFromPrimary)
    {
        if (readFromPrimary)
            return ReadPreference.PrimaryPreferred;
        else
            return ReadPreference.SecondaryPreferred;
    }

    // Builds the find query shared by the list/page methods: optional sort and optional projection
    protected IFindFluent<TEntity, TEntity> BuildFind(FilterDefinition<TEntity> filter, string[]? field, SortDefinition<TEntity>? sort, bool readFromPrimary)
    {
        var find = _dbSet.WithReadPreference(GetReadPreference(readFromPrimary)).Find(filter);
        if (sort != null)
            find = find.Sort(sort);
        if (field != null && field.Length > 0)
        {
            var projection = Builders<TEntity>.Projection.Combine(field.Select(f => Builders<TEntity>.Projection.Include(f)));
            find = find.Project<TEntity>(projection);
        }
        return find;
    }
    #endregion
}
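`UpdateAsync(expression, entity)` accepts an object-initializer lambda such as `x => new Todo { Title = "..." }` and turns each assignment into a `$set`. That expression walk can be tried on its own with System.Linq.Expressions only. In the sketch below, `TodoItem` and `UpdateFieldExtractor` are hypothetical, and it returns a name/value map instead of building `UpdateDefinition`s.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Hypothetical entity used only for this sketch.
public class TodoItem
{
    public string Id { get; set; } = "";
    public string Title { get; set; } = "";
    public int Priority { get; set; }
}

public static class UpdateFieldExtractor
{
    // Turns `x => new T { A = ..., B = ... }` into a property-name -> value map,
    // skipping the key property so it is never $set.
    public static Dictionary<string, object> Extract<T>(Expression<Action<T>> update, string keyProperty = "Id")
    {
        var fields = new Dictionary<string, object>();
        if (update.Body is not MemberInitExpression init)
            return fields;

        foreach (var binding in init.Bindings)
        {
            if (binding is not MemberAssignment assignment || binding.Member.Name == keyProperty)
                continue; // ignore nested/list bindings and the key

            // Constants are read directly; anything else (captured locals, method calls)
            // is compiled into a parameterless lambda and evaluated.
            var value = assignment.Expression is ConstantExpression constant
                ? constant.Value
                : Expression.Lambda(assignment.Expression).Compile().DynamicInvoke();

            fields[binding.Member.Name] = value;
        }
        return fields;
    }
}
```

Because values are evaluated through a parameterless lambda, an initializer that refers to the lambda's own parameter (for example `Priority = x.Priority + 1`) cannot be evaluated this way; that kind of change would need a different update operator such as `$inc`.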
Unit of work: UnitOfWork
In real projects, after operating on several repositories we want a single commit operation so that all the changes are applied atomically. A UnitOfWork can act as this proxy:
using System.Threading.Tasks;
using MongoDB.Driver;

public class UnitOfWork : IUnitOfWork
{
    private readonly IMongoDbContext _context;

    public UnitOfWork(IMongoDbContext context)
    {
        _context = context;
    }

    public bool SaveChanges(IClientSessionHandle session)
    {
        return _context.Commit(session) > 0;
    }

    public async Task<bool> SaveChangesAsync(IClientSessionHandle session)
    {
        return await _context.CommitAsync(session) > 0;
    }

    public IClientSessionHandle BeginTransaction()
    {
        return _context.StartSession();
    }

    public async Task<IClientSessionHandle> BeginTransactionAsync()
    {
        return await _context.StartSessionAsync();
    }

    public void Dispose()
    {
        _context.Dispose();
    }
}
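UnitOfWork only commits the commands queued in its own `IMongoDbContext`, so the repositories and the UnitOfWork used in one request need to share the same context instance (for example through a scoped registration); a repository holding a different instance would queue commands that this commit never sees. The in-memory sketch below, with hypothetical `InMemoryStore`, `SharedContext`, `Repository` and `UnitOfWorkSketch` types, shows two repositories feeding one commit. Its "transaction" is simulated by validating every pending write before applying any, which only mimics the all-or-nothing behavior MongoDB provides.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical in-memory "database": collection name -> stored documents.
public sealed class InMemoryStore
{
    public Dictionary<string, List<string>> Collections { get; } = new Dictionary<string, List<string>>();
}

// Shared context: every repository enqueues into the same pending list.
public sealed class SharedContext
{
    private readonly InMemoryStore _store;
    private readonly List<(string Collection, string Doc)> _pending = new List<(string Collection, string Doc)>();

    public SharedContext(InMemoryStore store) => _store = store;

    public void Enqueue(string collection, string doc) => _pending.Add((collection, doc));

    // Simulated all-or-nothing commit: validate everything first, then apply everything.
    public int Commit()
    {
        try
        {
            foreach (var (_, doc) in _pending)
                if (string.IsNullOrWhiteSpace(doc))
                    throw new InvalidOperationException("empty document");
            foreach (var (collection, doc) in _pending)
            {
                if (!_store.Collections.TryGetValue(collection, out var docs))
                    _store.Collections[collection] = docs = new List<string>();
                docs.Add(doc);
            }
            return _pending.Count;
        }
        catch (InvalidOperationException)
        {
            return 0;
        }
        finally
        {
            _pending.Clear();
        }
    }
}

public sealed class Repository
{
    private readonly SharedContext _context;
    private readonly string _collection;
    public Repository(SharedContext context, string collection) { _context = context; _collection = collection; }
    public void Add(string doc) => _context.Enqueue(_collection, doc);
}

public sealed class UnitOfWorkSketch
{
    private readonly SharedContext _context;
    public UnitOfWorkSketch(SharedContext context) => _context = context;
    public bool SaveChanges() => _context.Commit() > 0;
}
```

Note that, as with the `> 0` check in `UnitOfWork.SaveChanges`, an empty unit of work also reports `false`, so callers cannot tell "nothing to save" apart from "commit failed" by the return value alone.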
Encapsulation injection: ServiceCollectionExtensions