
How to Create DB Migration Tool in Go from Scratch


It is true that a backend application can be written in Go using nothing but the standard library. It is also true that you will miss out on certain framework features and will have to build them yourself if you need them.

One such feature is DB migrations. Libraries like go-migrate and goose can help you run DB migrations in Go, but they typically rely on raw .sql files, which means shipping a bunch of .sql files along with the binary. This article shows you how to create a simple DB migration tool in Go from scratch, with migrations written as regular .go files that are compiled into the binary. Let’s get started.

Step 1: Connecting to the Database

As the first step, we will write a simple function that connects to the database and returns the connection.

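package app

import (
	"database/sql"
	"fmt"

	// Registers the "mysql" driver with database/sql.
	_ "github.com/go-sql-driver/mysql"
)

// NewDB opens a MySQL connection pool and verifies it with Ping.
// It returns nil if the database cannot be reached.
func NewDB() *sql.DB {
	fmt.Println("Connecting to MySQL database...")

	// sql.Open only validates its arguments; it does not connect yet.
	db, err := sql.Open("mysql", "root:welcome@tcp(127.0.0.1:3306)/migrationtest")
	if err != nil {
		fmt.Println("Unable to connect to database", err.Error())
		return nil
	}

	// Ping opens a real connection to make sure the database is reachable.
	if err := db.Ping(); err != nil {
		fmt.Println("Unable to connect to database", err.Error())
		db.Close()
		return nil
	}

	fmt.Println("Database connected!")

	return db
}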


Step 2: Define Migration and Migrator

We will start writing our migration tool by defining a struct for a migration.

A migration has a version, an up function to upgrade the schema, and a down function to downgrade it. We will also add an unexported field, done, to track whether the migration has already been executed.

package migrations

import "database/sql"

type Migration struct {
	Version string
	Up      func(*sql.Tx) error
	Down    func(*sql.Tx) error

	done bool // true once the version is recorded in schema_migrations
}

Next, we will define a Migrator struct, along with a package-level variable that holds a *Migrator.

The migrator holds a collection of migrations: the migration versions are kept in a sorted slice, and the migrations themselves in a map keyed by version.

package migrations

// Code removed for brevity

type Migrator struct {
	db         *sql.DB
	Versions   []string
	Migrations map[string]*Migration
}

var migrator = &Migrator{
	Versions:   []string{},
	Migrations: map[string]*Migration{},
}

Step 3: Generating a New Migration

Before generating migrations, let us add a method to *Migrator that registers a new *Migration with it.

package migrations

// Code removed for brevity

func (m *Migrator) AddMigration(mg *Migration) {
	// Add the migration to the hash with version as key
	m.Migrations[mg.Version] = mg

	// Insert version into versions array using insertion sort
	index := 0
	for index < len(m.Versions) {
		if m.Versions[index] > mg.Version {
			break
		}
		index++
	}

	m.Versions = append(m.Versions, mg.Version)
	copy(m.Versions[index+1:], m.Versions[index:])
	m.Versions[index] = mg.Version
}
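Because versions are fixed-width, 14-digit timestamps, plain string comparison orders them chronologically. Since m.Versions is always kept sorted, the linear scan in AddMigration could also be replaced by a binary search. A minimal sketch using the standard library's sort.SearchStrings; the helper name insertSorted is hypothetical.

```go
package main

import "sort"

// insertSorted inserts v into the ascending slice versions and returns the
// result, using binary search (sort.SearchStrings) to find the position.
func insertSorted(versions []string, v string) []string {
	i := sort.SearchStrings(versions, v) // first index with versions[i] >= v
	versions = append(versions, "")      // grow the slice by one element
	copy(versions[i+1:], versions[i:])   // shift the tail right by one
	versions[i] = v
	return versions
}
```

For the handful of migrations a typical project has, the difference in speed is negligible; the main benefit is that the intent is explicit.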

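We will generate new migration files from a template stored at ./migrations/template.txt. The template defines the up and down functions for the migration and an init() function that registers the *Migration with the migrator using the AddMigration method we just added. Go runs every init() function in a package before main starts, so each migration file registers itself automatically.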

package migrations

import "database/sql"

func init() {
	migrator.AddMigration(&Migration{
		Version: "{{.Version}}",
		Up:      mig_{{.Version}}_{{.Name}}_up,
		Down:    mig_{{.Version}}_{{.Name}}_down,
	})
}

func mig_{{.Version}}_{{.Name}}_up(tx *sql.Tx) error {
	return nil
}

func mig_{{.Version}}_{{.Name}}_down(tx *sql.Tx) error {
	return nil
}
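The template splices {{.Name}} directly into Go function names such as mig_{{.Version}}_{{.Name}}_up, so a name like init-schema would produce a file that does not compile. A minimal validation sketch; the pattern (lowercase letters, digits and underscores, starting with a letter) is an assumption chosen for this example, and checkMigrationName is a hypothetical helper.

```go
package main

import (
	"fmt"
	"regexp"
)

// validName matches names that are safe to splice into identifiers like
// mig_<version>_<name>_up. The exact pattern is an assumption, not a Go rule.
var validName = regexp.MustCompile(`^[a-z][a-z0-9_]*$`)

// checkMigrationName returns an error if name cannot be used in the template.
func checkMigrationName(name string) error {
	if !validName.MatchString(name) {
		return fmt.Errorf("invalid migration name %q: use lowercase letters, digits and underscores", name)
	}
	return nil
}
```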

Once we have the template, we can write a function that will execute the template and create a new file out of it.

package migrations

import (
	"bytes"
	"fmt"
	"os"
	"text/template"
	"time"
)

// Code removed for brevity

// Create renders template.txt into a new migration file named
// <version>_<name>.go, where version is the current timestamp.
func Create(name string) error {
	version := time.Now().Format("20060102150405")

	in := struct {
		Version string
		Name    string
	}{
		Version: version,
		Name:    name,
	}

	var out bytes.Buffer

	// The path is relative to the working directory, so run this from the project root.
	t, err := template.ParseFiles("./migrations/template.txt")
	if err != nil {
		return fmt.Errorf("unable to parse template: %w", err)
	}
	if err := t.Execute(&out, in); err != nil {
		return fmt.Errorf("unable to execute template: %w", err)
	}

	f, err := os.Create(fmt.Sprintf("./migrations/%s_%s.go", version, name))
	if err != nil {
		return fmt.Errorf("unable to create migration file: %w", err)
	}
	defer f.Close()

	if _, err := f.Write(out.Bytes()); err != nil {
		return fmt.Errorf("unable to write to migration file: %w", err)
	}

	fmt.Println("Generated new migration file:", f.Name())
	return nil
}

That is all we need to generate new migration files.

Step 4: Storing and Retrieving Migration Status

We need a way to find out the current state of the database – which migrations have already been executed and which are still pending.

For this, we will create a new table named schema_migrations. This table will have just one column – version. When we run an up migration, we insert its version into this table, and when we run a down migration, we delete its version from this table.

While initializing our migrator, we can read this table and mark the migrations as done using the done flag that we have in the Migration struct.

package migrations

// Code removed for brevity

// Init attaches db to the migrator, creates the schema_migrations table if
// needed, and marks every migration recorded in that table as done.
func Init(db *sql.DB) (*Migrator, error) {
	migrator.db = db

	// Create `schema_migrations` table to remember which migrations were executed.
	if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (
		version varchar(255)
	);`); err != nil {
		fmt.Println("Unable to create `schema_migrations` table", err)
		return migrator, err
	}

	// Find out all the executed migrations
	rows, err := db.Query("SELECT version FROM `schema_migrations`;")
	if err != nil {
		return migrator, err
	}
	defer rows.Close()

	// Mark the migrations as done if they were already executed
	for rows.Next() {
		var version string
		if err := rows.Scan(&version); err != nil {
			return migrator, err
		}

		if mg := migrator.Migrations[version]; mg != nil {
			mg.done = true
		}
	}

	// rows.Err reports any error that ended the iteration early.
	return migrator, rows.Err()
}
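Underneath, Init is a set lookup: a version is done if it appears in schema_migrations. A minimal sketch of the same idea as a pure function (the name pending is hypothetical) that returns the versions still to run, in order, and also reports versions recorded in the table that have no registered migration. Init above silently skips that case, which can happen if a migration file was deleted or the binary is older than the database.

```go
package main

// pending returns the sorted versions that are not in applied, plus any
// applied versions that have no registered migration.
func pending(sorted, applied []string) (todo, unknown []string) {
	known := make(map[string]bool, len(sorted))
	for _, v := range sorted {
		known[v] = true
	}

	done := make(map[string]bool, len(applied))
	for _, v := range applied {
		done[v] = true
		if !known[v] {
			unknown = append(unknown, v) // recorded in the DB, missing in code
		}
	}

	for _, v := range sorted {
		if !done[v] {
			todo = append(todo, v) // registered but not yet executed
		}
	}
	return todo, unknown
}
```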

Step 5: Running Up Migrations

Now that we can generate and register migrations and find out the current state of the database, the next step is to run the migrations themselves.

We can add a new method Up on the Migrator struct that runs all pending migrations and takes a step parameter to limit how many of them to run; a step of 0 runs all of them.

Since our migrator keeps the versions in a sorted slice, we can simply loop through it and execute the migrations that have not run yet.

It is important to run the migrations inside a SQL transaction so that we can roll back in case of any error. Be aware, though, that MySQL implicitly commits DDL statements such as CREATE TABLE and DROP TABLE, so schema changes made before a failure are not undone by the rollback; databases with transactional DDL, such as PostgreSQL, do roll them back.

Every time a migration runs successfully, we will insert its version into the schema_migrations table.

package migrations

// Code removed for brevity

// Up runs pending migrations in version order inside a single transaction.
// A step of 0 runs all pending migrations; a positive step limits the count.
func (m *Migrator) Up(step int) error {
	tx, err := m.db.Begin()
	if err != nil {
		return err
	}

	count := 0
	for _, v := range m.Versions {
		if step > 0 && count == step {
			break
		}

		mg := m.Migrations[v]

		if mg.done {
			continue
		}

		fmt.Println("Running migration", mg.Version)
		if err := mg.Up(tx); err != nil {
			tx.Rollback()
			return err
		}

		// Record the version so the migration is not run again.
		if _, err := tx.Exec("INSERT INTO `schema_migrations` VALUES(?)", mg.Version); err != nil {
			tx.Rollback()
			return err
		}
		fmt.Println("Finished running migration", mg.Version)

		count++
	}

	// Commit can fail too, so return its error instead of ignoring it.
	return tx.Commit()
}

The above code runs all the migrations in a single SQL transaction; if you prefer, you can run each migration in its own transaction instead.
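The selection logic inside Up (skip migrations that are already done, stop after step of them) can be pulled out into a pure function and tested without a database. A minimal sketch with a hypothetical planUp helper that takes the sorted versions and a set of applied versions; looping over its result is also a natural place to open one transaction per migration if you go that route.

```go
package main

// planUp returns the versions Up would run: pending versions in ascending
// order, limited to step entries when step > 0 (step <= 0 means "all").
func planUp(sorted []string, done map[string]bool, step int) []string {
	var plan []string
	for _, v := range sorted {
		if step > 0 && len(plan) == step {
			break
		}
		if !done[v] {
			plan = append(plan, v)
		}
	}
	return plan
}
```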

Step 6: Running Down Migrations

To run the down migrations, we can add a new method Down on the Migrator struct that reverts the executed migrations and takes a step parameter to limit how many of them to revert; a step of 0 reverts all of them.

Since the versions are kept in a sorted slice, we can walk it in reverse order and revert the migrations that have already been executed.

Once again, we will run these in a SQL transaction.

Every time a migration is reverted successfully, we will delete its version from the schema_migrations table.

package migrations

// Code removed for brevity

// Down reverts executed migrations, newest first, inside a single transaction.
// A step of 0 reverts all of them; a positive step limits the count.
func (m *Migrator) Down(step int) error {
	tx, err := m.db.Begin()
	if err != nil {
		return err
	}

	count := 0
	for _, v := range reverse(m.Versions) {
		if step > 0 && count == step {
			break
		}

		mg := m.Migrations[v]

		if !mg.done {
			continue
		}

		fmt.Println("Reverting Migration", mg.Version)
		if err := mg.Down(tx); err != nil {
			tx.Rollback()
			return err
		}

		// Forget the version so the migration can be run again later.
		if _, err := tx.Exec("DELETE FROM `schema_migrations` WHERE version = ?", mg.Version); err != nil {
			tx.Rollback()
			return err
		}
		fmt.Println("Finished reverting migration", mg.Version)

		count++
	}

	// Commit can fail too, so return its error instead of ignoring it.
	return tx.Commit()
}

// reverse returns a reversed copy of arr, so m.Versions keeps its order.
func reverse(arr []string) []string {
	out := make([]string, len(arr))
	for i, v := range arr {
		out[len(arr)-1-i] = v
	}
	return out
}
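The reverse helper exists only to visit the versions newest-first. A minimal sketch of the equivalent selection for Down, using a hypothetical planDown helper that walks the sorted slice backwards instead, so no copy is made and the caller's slice is never reordered.

```go
package main

// planDown returns the versions Down would revert: applied versions in
// descending order, limited to step entries when step > 0. It indexes the
// slice from the end rather than reversing it, leaving the input unchanged.
func planDown(sorted []string, done map[string]bool, step int) []string {
	var plan []string
	for i := len(sorted) - 1; i >= 0; i-- {
		if step > 0 && len(plan) == step {
			break
		}
		if done[sorted[i]] {
			plan = append(plan, sorted[i])
		}
	}
	return plan
}
```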

Step 7: Printing the Status of Migrations

Since Init reads the schema_migrations table and sets the done flag during initialization, we can now simply loop through the migrations and print their status.

package migrations

// MigrationStatus prints whether each migration, in version order,
// has been completed or is still pending.
func (m *Migrator) MigrationStatus() error {
	for _, v := range m.Versions {
		mg := m.Migrations[v]

		if mg.done {
			fmt.Printf("Migration %s... completed\n", v)
		} else {
			fmt.Printf("Migration %s... pending\n", v)
		}
	}

	return nil
}

Step 8: Creating the CLI Using Cobra

The final step is to wrap our migration tool in a CLI so that users can run migrations from the command line.

We will add support for the following 4 commands.

go run main.go migrate create -n migration_name
go run main.go migrate status
go run main.go migrate up [-s 2]
go run main.go migrate down [-s 2]

We will use Cobra to build the CLI. Go ahead and add it to your module.

go get -u github.com/spf13/cobra

Explaining Cobra in detail is out of the scope of this article. If you want to know more about it, please read its documentation.

We will keep all the CLI code inside the cmd folder. Let us start by creating the root command.

package cmd

import (
	"log"

	"github.com/spf13/cobra"
)

var rootCmd = &cobra.Command{
	Use:   "app",
	Short: "Application Description",
}

// Execute runs the root command and exits the program if it fails.
func Execute() {
	if err := rootCmd.Execute(); err != nil {
		log.Fatalln(err.Error())
	}
}

Now that we have the root command, we can go ahead and create our main.go file where we will simply execute this root command.

package main

import "github.com/praveen001/go-db-migration/cmd"

func main() {
	cmd.Execute()
}

Once we have created the main.go and the root command, we can create the migrate command.

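package cmd

import (
	"fmt"

	// Assumes the app package from Step 1 lives in the app/ folder of this module.
	"github.com/praveen001/go-db-migration/app"
	"github.com/praveen001/go-db-migration/migrations"
	"github.com/spf13/cobra"
)

// migrateCmd groups the migration subcommands. It has no Run function, so
// running `migrate` on its own prints the help text.
var migrateCmd = &cobra.Command{
	Use:   "migrate",
	Short: "database migrations tool",
}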

Finally, we can create and add the create, up, down and status commands to the migrate command.

package cmd

// Code removed for brevity

var migrateCreateCmd = &cobra.Command{
	Use:   "create",
	Short: "create a new empty migration file",
	Run: func(cmd *cobra.Command, args []string) {
		name, err := cmd.Flags().GetString("name")
		if err != nil {
			fmt.Println("Unable to read flag `name`", err.Error())
			return
		}

		// The name becomes part of the file name and the function names.
		if name == "" {
			fmt.Println("Please provide a migration name with --name (-n)")
			return
		}

		if err := migrations.Create(name); err != nil {
			fmt.Println("Unable to create migration", err.Error())
			return
		}
	},
}

var migrateUpCmd = &cobra.Command{
	Use:   "up",
	Short: "run up migrations",
	Run: func(cmd *cobra.Command, args []string) {
		step, err := cmd.Flags().GetInt("step")
		if err != nil {
			fmt.Println("Unable to read flag `step`:", err)
			return
		}

		// NewDB returns nil (after printing the reason) if it cannot connect.
		db := app.NewDB()
		if db == nil {
			return
		}
		defer db.Close()

		migrator, err := migrations.Init(db)
		if err != nil {
			fmt.Println("Unable to fetch migrator:", err)
			return
		}

		if err := migrator.Up(step); err != nil {
			fmt.Println("Unable to run `up` migrations:", err)
			return
		}
	},
}

var migrateDownCmd = &cobra.Command{
	Use:   "down",
	Short: "run down migrations",
	Run: func(cmd *cobra.Command, args []string) {
		step, err := cmd.Flags().GetInt("step")
		if err != nil {
			fmt.Println("Unable to read flag `step`:", err)
			return
		}

		// NewDB returns nil (after printing the reason) if it cannot connect.
		db := app.NewDB()
		if db == nil {
			return
		}
		defer db.Close()

		migrator, err := migrations.Init(db)
		if err != nil {
			fmt.Println("Unable to fetch migrator:", err)
			return
		}

		if err := migrator.Down(step); err != nil {
			fmt.Println("Unable to run `down` migrations:", err)
			return
		}
	},
}

var migrateStatusCmd = &cobra.Command{
	Use:   "status",
	Short: "display the status of each migration",
	Run: func(cmd *cobra.Command, args []string) {
		// NewDB returns nil (after printing the reason) if it cannot connect.
		db := app.NewDB()
		if db == nil {
			return
		}
		defer db.Close()

		migrator, err := migrations.Init(db)
		if err != nil {
			fmt.Println("Unable to fetch migrator:", err)
			return
		}

		if err := migrator.MigrationStatus(); err != nil {
			fmt.Println("Unable to fetch migration status:", err)
		}
	},
}

func init() {
	// Add "--name" flag to the "create" command
	migrateCreateCmd.Flags().StringP("name", "n", "", "Name for the migration")

	// Add "--step" flag to both the "up" and "down" commands (0 means all)
	migrateUpCmd.Flags().IntP("step", "s", 0, "Number of migrations to execute")
	migrateDownCmd.Flags().IntP("step", "s", 0, "Number of migrations to revert")

	// Add the "create", "up", "down" and "status" commands to the "migrate" command
	migrateCmd.AddCommand(migrateUpCmd, migrateDownCmd, migrateCreateCmd, migrateStatusCmd)

	// Add the "migrate" command to the root command
	rootCmd.AddCommand(migrateCmd)
}

That is everything that we need to create a DB Migration tool.

Demo: DB Migrations in Go

Now that we have finished creating the tool, let us try to create and run a migration.

We will first create a new migration file by running the create command.

go run main.go migrate create -n init_schema

Let us open the newly created migration file and write the queries for our schema migration: the up function creates a users table with a single name column, and the down function drops it.

package migrations

import "database/sql"

func init() {
	migrator.AddMigration(&Migration{
		Version: "20200830120717",
		Up:      mig_20200830120717_init_schema_up,
		Down:    mig_20200830120717_init_schema_down,
	})
}

func mig_20200830120717_init_schema_up(tx *sql.Tx) error {
	_, err := tx.Exec("CREATE TABLE users ( name varchar(255) );")
	if err != nil {
		return err
	}
	return nil
}

func mig_20200830120717_init_schema_down(tx *sql.Tx) error {
	_, err := tx.Exec("DROP TABLE users")
	if err != nil {
		return err
	}
	return nil
}

We can now run the migration with the following command.

go run main.go migrate up

If you see output like the following, the migration has finished successfully.

[Screenshot: output of the migrate up command]

We can check our database to see if the schema changes were applied as expected.

[Screenshot: database output after running the migration]

We can check the status of migrations by running the following command.

go run main.go migrate status

This should produce output like the following.

[Screenshot: output of the migrate status command]

Finally, let us try reverting the schema changes by running the down command.

go run main.go migrate down

If you see output like the following, the migration was reverted successfully.

[Screenshot: output of the migrate down command]

Now, if you check the database, you can see that the users table has been dropped and the schema_migrations table is empty again.

[Screenshot: database reflecting the reverted schema changes]

Awesome! Everything seems to be working. We now have our very own tool for DB migrations in Go that lets us write migrations in .go files.

You can see the full source code here: github.com/praveen001/go-db-migration
