Unit of Work (Transaction Manager)

The Unit of Work pattern keeps data consistent by grouping related database operations into a single atomic unit. Either every operation takes effect or none does, which prevents partial updates and preserves referential integrity.

Applicability

  • You need to perform multiple related database operations that must all succeed or fail together
  • You want to avoid scattered transaction management code across your application
  • You need to coordinate operations across multiple repositories within a single transaction
  • You want to decouple business logic from transaction management details
  • You’re implementing complex workflows that involve multiple database tables

Example: Standard library (database/sql)

Wrap BeginTx, Commit, and Rollback in a single helper, then pass the work in as a function. Callers keep their business logic in one place, and every transaction is committed or rolled back the same way.

package main

import (
	"context"
	"database/sql"
)

// RunInTx runs fn inside one transaction opened on db with default options.
// When fn returns nil the transaction is committed and Commit's result is
// returned. When fn returns an error or panics, the transaction is rolled back
// and fn's error (or the panic) is propagated unchanged.
func RunInTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Safety net for the error and panic paths. Once Commit has run, Rollback
	// only reports sql.ErrTxDone, which is deliberately discarded.
	defer func() {
		_ = tx.Rollback()
	}()

	err = fn(tx)
	if err == nil {
		err = tx.Commit()
	}
	return err
}

// PlaceOrder inserts an orders row and one order_items row that references it
// in a single transaction. If either INSERT fails, neither row is kept.
//
// The "?" placeholders follow MySQL/SQLite syntax. PostgreSQL drivers (e.g.
// lib/pq, pgx) expect $1, $2, … instead.
func PlaceOrder(ctx context.Context, db *sql.DB, orderID, itemID string) error {
	const (
		insertOrder = `INSERT INTO orders (id) VALUES (?)`
		insertItem  = `INSERT INTO order_items (order_id, id) VALUES (?, ?)`
	)
	return RunInTx(ctx, db, func(tx *sql.Tx) error {
		_, err := tx.ExecContext(ctx, insertOrder, orderID)
		if err == nil {
			// Only attempt the child row once the parent row exists.
			_, err = tx.ExecContext(ctx, insertItem, orderID, itemID)
		}
		return err
	})
}

Example: Ent (code-generated client)

A more advanced, Ent-specific setup follows. It has a UnitOfWork interface, a transactional *ent.Client carried in the context (retrieved with GetClient), panic-safe rollback, and a service that runs several repository calls in one transaction.

package main

import (
	"context"
	"fmt"
	"log/slog"

	ent "example.com/yourapp/ent" // your generated Ent client package
)

// UnitOfWork is the transaction boundary. Do runs fn so that the work fn
// performs through the context it receives is committed or rolled back as
// one unit.
type UnitOfWork interface {
	// Do runs fn within a transaction. Implementations (such as
	// entUnitOfWork) pass fn a derived context, and fn should use that ctx,
	// not the outer one, for its operations to join the transaction.
	Do(ctx context.Context, fn func(ctx context.Context) error) error
}

// txKey is the context key under which Do stores the transactional
// *ent.Client. Because it is an unexported struct type, it cannot collide with
// context keys defined in other packages.
type txKey struct{}

// entUnitOfWork implements UnitOfWork for *ent.Client.
type entUnitOfWork struct {
	// client is the client Do starts every transaction from (via client.Tx).
	client *ent.Client
}

// NewEntUnitOfWork returns a UnitOfWork whose Do opens its transactions on c.
func NewEntUnitOfWork(c *ent.Client) UnitOfWork {
	uow := new(entUnitOfWork)
	uow.client = c
	return uow
}

// Do runs fn inside a single Ent transaction. fn receives a context carrying
// the transactional *ent.Client (read it with GetClient), so repository calls
// made with that context join the transaction.
//
// Outcomes:
//   - If fn returns nil, the transaction is committed. A commit failure is
//     returned wrapped as "commit: ...".
//   - If fn returns an error, the transaction is rolled back and fn's error is
//     returned. If the rollback also fails, both errors are wrapped, so fn's
//     error stays matchable with errors.Is / errors.As.
//   - If fn panics, the transaction is rolled back and the panic is re-raised.
//
// NOTE(review): Do does not look for a transaction already stored in ctx by an
// outer Do. It always calls u.client.Tx, so, assuming u.client is the
// non-transactional client, a nested Do opens a second, independent
// transaction. TODO confirm whether nested calls should join the outer one.
func (u *entUnitOfWork) Do(ctx context.Context, fn func(ctx context.Context) error) error {
	tx, err := u.client.Tx(ctx)
	if err != nil {
		return fmt.Errorf("start transaction: %w", err)
	}

	// Roll back on panic, then re-panic so the caller still sees the failure.
	defer func() {
		if v := recover(); v != nil {
			_ = tx.Rollback()
			panic(v)
		}
	}()

	txCtx := context.WithValue(ctx, txKey{}, tx.Client())

	if err := fn(txCtx); err != nil {
		if rerr := tx.Rollback(); rerr != nil {
			// Wrap both errors. Formatting fn's error with %v would hide it
			// from errors.Is / errors.As in callers.
			return fmt.Errorf("%w (rollback: %w)", err, rerr)
		}
		return err
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit: %w", err)
	}
	return nil
}

// GetClient picks the *ent.Client that work running under ctx should use.
// It returns the transactional client stored by entUnitOfWork.Do when ctx
// came from inside Do. Otherwise it returns defaultClient.
func GetClient(ctx context.Context, defaultClient *ent.Client) *ent.Client {
	txClient, inTx := ctx.Value(txKey{}).(*ent.Client)
	if !inTx {
		return defaultClient
	}
	return txClient
}

// UserService implements user operations that touch several repositories.
// Multi-step changes run through uow so they commit or roll back together.
type UserService struct {
	// uow is the transaction boundary for multi-repository writes.
	uow               UnitOfWork
	// logger records failures of individual steps.
	logger            *slog.Logger
	commentRepository CommentRepository
	postRepository    PostRepository
	sessionRepository SessionRepository
	userRepository    UserRepository
}

// Delete removes the user identified by userID together with the user's
// comments, posts, and sessions in a single unit of work. If any step fails,
// the failure is logged and the error is returned from the Do callback, which
// lets s.uow roll back every earlier step. The returned error is wrapped with
// the failing step's description. The underlying repository error stays
// reachable through errors.Is / errors.As.
//
// Comments, posts, and sessions are deleted before the user row, presumably so
// no rows still reference the user. Adjust the order to your schema.
// NOTE(review): if DeleteByUserID removes only comments this user wrote,
// other users' comments on this user's posts remain. If comments reference
// posts via a foreign key without ON DELETE CASCADE, the post deletion may
// fail. Confirm against the schema.
func (s *UserService) Delete(ctx context.Context, userID int) error {
	return s.uow.Do(ctx, func(txCtx context.Context) error {
		// fail logs a failed step and returns its error wrapped with the same
		// description, so callers can tell which step broke.
		fail := func(step string, err error) error {
			s.logger.Error(step,
				"error", err,
				"userID", userID,
			)
			return fmt.Errorf("%s: %w", step, err)
		}

		if err := s.commentRepository.DeleteByUserID(txCtx, userID); err != nil {
			return fail("delete comments by user", err)
		}
		if err := s.postRepository.DeleteByAuthorID(txCtx, userID); err != nil {
			return fail("delete posts by author", err)
		}
		if err := s.sessionRepository.DeleteByUserID(txCtx, userID); err != nil {
			return fail("delete sessions for user", err)
		}
		if err := s.userRepository.Delete(txCtx, userID); err != nil {
			return fail("delete user", err)
		}
		return nil
	})
}

// The repository interfaces below are the persistence operations
// UserService.Delete needs. Implementations should obtain their client with
// GetClient(ctx, appClient) and use Ent builders (e.g. client.Comment(),
// client.Post()). They then join the caller's transaction when ctx comes from
// UnitOfWork.Do, and use the non-transactional client otherwise.

// CommentRepository deletes comments on behalf of UserService.
type CommentRepository interface {
	// DeleteByUserID deletes the comments belonging to userID.
	DeleteByUserID(ctx context.Context, userID int) error
}

// PostRepository deletes posts on behalf of UserService.
type PostRepository interface {
	// DeleteByAuthorID deletes the posts authored by userID.
	DeleteByAuthorID(ctx context.Context, userID int) error
}

// SessionRepository deletes sessions on behalf of UserService.
type SessionRepository interface {
	// DeleteByUserID deletes the sessions belonging to userID.
	DeleteByUserID(ctx context.Context, userID int) error
}

// UserRepository deletes user rows on behalf of UserService.
type UserRepository interface {
	// Delete deletes the user row identified by userID.
	Delete(ctx context.Context, userID int) error
}

This site is open source! You can contribute or suggest changes by editing the GitHub repository.
Copyright © 2026 Viacheslav Demianov. Distributed under the MIT license.