Go の静的解析で DB へのコミット漏れを検出する

エンジニアの佐野です。カンムはバックエンドに PostgreSQL を置きつつサーバを Go で書いています。DB のトランザクションの取り回しは概ね次の様なイディオムになっているのですが、先日 Commit() が漏れている箇所を見つけまして...。結果としてそれについては大きな問題はなく秋の夜長に遅めの肝試しをする程度で済んだのですが、これは事故に繋がるためトランザクションの Commit 漏れ(defer Rollback() 漏れも)を検出する Linter を書きました。

  tx, err := db.BeginTx(ctx, nil)
  if err != nil {
    return err
  }
  defer tx.Rollback()

  // ...

  if err := tx.Commit(); err != nil {
    return err
  }

  // ...

Linter の方針

次のような方針とします

  • 意図的にコミットを書かず Rollback() させる前提の実装もあるかもしれないが、Commit() が書かれていないものは指摘の対象とする。ようは tx, err := Begin() もしくは tx, err := BeginTx() があったらそれと対になる defer tx.Rollback() と tx.Commit() のペアがちゃんと実装されているかチェックする。
  • Go での DB トランザクションの取り回し方はいろいろあるとは思うがカンムで主流のイディオムになっている前提とする。これはいきなり汎用的なツールとせずにまずはカンム内の問題を解決することに焦点を当てるため。

txchk

上記方針にて txchk という Linter を書きました。txchk は golang.org/x/tools/go/analysis.Analyzer を利用して開発しています。golang.org/x/tools/go/analysis.Analyzer は静的解析用のパッケージでこれを使うと比較的簡単に Linter を書くことができます。またコマンドの実装を次のように golang.org/x/tools/go/analysis/singlechecker.Main() 経由で Analyzer を呼ぶようにすることで go vet に組み込んで go vet -vettool=$(which txchk) ./... として使えるようにすることもできます。

package main

import (
        "github.com/kanmu/txchk"
        "golang.org/x/tools/go/analysis/singlechecker"
)

func main() { singlechecker.Main(txchk.Analyzer) }
go vet -vettool=$(which txchk) ./...

以下が txchk を書いて Go のコミット漏れを調査してその漏れを一掃したパッチです。結構見つかりました...冒頭およびPR 本文にも書いてありますが結果として漏れていても問題はなかったです。現在は txchk を CI の lint ステージに組み込んで go vet 経由で動かしています。

txchk の処理の流れと txchk の概観

処理の流れは次の通りです。

  • ターゲットとなる Go のリポジトリのすべての関数を対象に各関数を先頭から走査
  • Begin() または BeginTx() している箇所を見つける
  • Begin() または BeginTx() が見つかったらその左辺を走査して1つめの戻り値の変数名 (例: tx) を記憶
  • それ以降の処理を走査して tx.Commit() を呼んでいればコミットされているとみなす
  • 関数を最後まで走査し、tx.Commit() が無かったら Lint エラーとする
  • (Rollback() についても上記と同様の処理を行う)

この処理の概観は以下の通りです。

package txchk

import (
  "go/ast"
  "go/token"
  "strings"

  "golang.org/x/tools/go/analysis"
  "golang.org/x/tools/go/analysis/passes/inspect"
  "golang.org/x/tools/go/ast/inspector"
)

var Analyzer = &analysis.Analyzer{
  Name: "txchk",
  Doc:  Doc,
  Run:  run,
  Requires: []*analysis.Analyzer{
    inspect.Analyzer,
  },
}

const Doc = "txchk"

func run(pass *analysis.Pass) (interface{}, error) {
  inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

  nodeFilter := []ast.Node{
    (*ast.FuncDecl)(nil),
  }

  inspect.Preorder(nodeFilter, func(n ast.Node) {
    pos := pass.Fset.Position(n.Pos())

    switch node := n.(type) {
    case *ast.FuncDecl:
      if strings.HasPrefix(node.Name.Name, "Test") || strings.HasPrefix(node.Name.Name, "test") {
        return
      }
      for _, stmt := range node.Body.List {
        switch s := stmt.(type) {
        case *ast.AssignStmt:
          for _, expr := range s.Rhs {
            if callExpr, ok := expr.(*ast.CallExpr); ok {
              found := findTransactionTypeBegin(callExpr)
              if found {
                beginPos := pass.Fset.Position(s.Pos()).Line
                for i, lh := range s.Lhs {
                  if i == 0 {
                    if ident, ok := lh.(*ast.Ident); ok {
                      committed := isCommitImplemented(pass.Fset, node.Body, beginPos, ident.Name)
                      if !committed {
                        pass.Reportf(s.Pos(), "transaction must be committed: %s", ident.Name)
                      }
                      rollbacked := isRollbackImplemented(pass.Fset, node.Body, beginPos, ident.Name)
                      if !rollbacked {
                        pass.Reportf(s.Pos(), "transaction must be rollbacked: %s", ident.Name)
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  })
  return nil, nil
}

処理の解説

ひとつずつ解説していきます。まず以下の部分は analysis.Analyzer を使って Linter を書くときのガワのようなものです。

var Analyzer = &analysis.Analyzer{
  Name: "txchk",
  Doc:  Doc,
  Run:  run,
  Requires: []*analysis.Analyzer{
    inspect.Analyzer,
  },
}

const Doc = "txchk"

func run(pass *analysis.Pass) (interface{}, error) {
  inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

  nodeFilter := []ast.Node{
    // フィルタ
  }

  var err error
  inspect.Preorder(nodeFilter, func(n ast.Node) {
    // 好きな処理
  }
  return nil, err
}

「フィルタ」にて処理したい箇所を指定、その箇所について「好きな処理」を書きます。今回はすべての関数を対象にしたいので (*ast.FuncDecl)(nil) をフィルタに入れていますが、例えばソースコード内に定義された構造体をターゲットに対して何らかの静的解析を施したい場合は (*ast.StructType)(nil) をフィルタに書くことで inspect.Preorder(nodeFilter, func(n ast.Node) には構造体のノードのみが入ってくるようにすることができます。「好きな処理」ではその通りフィルタされたノードについてやりたいことを書いていけばよいです。

続いて「Begin() または BeginTx() している箇所を見つける」処理ですが、ノードが *ast.FuncDecl のものについて node.Body.List を for...range で回すことでその関数を先頭から走査します。そして *ast.AssignStmt 、つまり代入があったとき、その右辺 (s.Rhs) を調査し、その右辺が *ast.CallExpr (関数呼び出し)であれば、それが BeginBeginTx かを調べます。

    switch node := n.(type) {
    case *ast.FuncDecl:
      if strings.HasPrefix(node.Name.Name, "Test") || strings.HasPrefix(node.Name.Name, "test") {
        return
      }
      for _, stmt := range node.Body.List {
        switch s := stmt.(type) {
        case *ast.AssignStmt:
          for _, expr := range s.Rhs {
            if callExpr, ok := expr.(*ast.CallExpr); ok {
              found := findTransactionTypeBegin(callExpr)
...

findTransactionTypeBegin() の実装は次の様になっています。これは db.Begin() もしくは db.BeginTx() を探す実装で、再起処理をしているのは Begin() や BeginTx() が app.DB.Begin() のようになっていたりする場合もそれを辿って Begin() および BeginTx() を見つけられるようにするためです。

func findTransactionTypeBegin(node *ast.CallExpr) bool {
  if fun, ok := node.Fun.(*ast.SelectorExpr); ok {
    if fun.Sel.Name == "Begin" || fun.Sel.Name == "BeginTx" {
      return true
    }
  }
  if n, ok := node.Fun.(*ast.CallExpr); ok {
    return findTransactionTypeBegin(n)
  }
  return false
}

そして Begin() または BeginTx() が見つかったら左辺 (s.Lhs) を走査して1つめの戻り値の変数名が格納されている ident.Name を「コミットが実装されているか?」を調査する関数 (isCommitImplemented()) に渡して、 ident.Name (例: tx) について tx.Commit() が実装されているかを調べます。

              found := findTransactionTypeBegin(callExpr)
              if found {
                beginPos := pass.Fset.Position(s.Pos()).Line
                for i, lh := range s.Lhs {
                  if i == 0 {
                    if ident, ok := lh.(*ast.Ident); ok {
                      committed := isCommitImplemented(pass.Fset, node.Body, beginPos, ident.Name)
...

isCommitImplemented() は次の通りです。tx.Rollback() を探すときも似たような処理になるので isTransactionFinished() のラッパにしています。TransactionTypeCommit, TransactionTypeRollback はそれぞれ type TransactionType int と iota で定義した列挙です。

func isCommitImplemented(fset *token.FileSet, body *ast.BlockStmt, beginPos int, txIdentName string) bool {
  return isTransactionFinished(fset, body, beginPos, txIdentName, TransactionTypeCommit)
}

func isRollbackImplemented(fset *token.FileSet, body *ast.BlockStmt, beginPos int, txIdentName string) bool {
  return isTransactionFinished(fset, body, beginPos, txIdentName, TransactionTypeRollback)
}

func isTransactionFinished(fset *token.FileSet, body *ast.BlockStmt, beginPos int, txIdentName string, tranType TransactionType) bool {
  for _, stmt := range body.List {
    pos := fset.Position(stmt.Pos())
    if pos.Line < beginPos {
      continue
    }

    switch s := stmt.(type) {
    case *ast.IfStmt:
      if s.Init != nil {
        if assignStmt, ok := s.Init.(*ast.AssignStmt); ok {
          for _, expr := range assignStmt.Rhs {
            if callExpr, ok := expr.(*ast.CallExpr); ok {
              if findTransactionTypeByIdentName(callExpr, txIdentName, tranType) {
                return true
              }
            }
          }
        }
      }
      if isTransactionFinished(fset, s.Body, pos.Line, txIdentName, tranType) {
        return true
      }
    case *ast.BlockStmt:
      if isTransactionFinished(fset, s, pos.Line, txIdentName, tranType) {
        return true
      }
    case *ast.ExprStmt: // tx.Commit(), func() { tx.Commit }()
      if callExpr, ok := s.X.(*ast.CallExpr); ok {
        if findTransactionTypeByIdentName(callExpr, txIdentName, tranType) {
          return true
        }
        if n, ok := callExpr.Fun.(*ast.FuncLit); ok {
          if isTransactionFinished(fset, n.Body, pos.Line, txIdentName, tranType) {
            return true
          }
        }
      }
    case *ast.ReturnStmt: // return tx.Commit()
      for _, rtn := range s.Results {
        if callExpr, ok := rtn.(*ast.CallExpr); ok {
          if findTransactionTypeByIdentName(callExpr, txIdentName, tranType) {
            return true
          }
        }
      }
    case *ast.AssignStmt: // err := tx.Commit()
      for _, expr := range s.Rhs {
        if callExpr, ok := expr.(*ast.CallExpr); ok {
          if findTransactionTypeByIdentName(callExpr, txIdentName, tranType) {
            return true
          }
          if n, ok := callExpr.Fun.(*ast.FuncLit); ok {
            if isTransactionFinished(fset, n.Body, pos.Line, txIdentName, tranType) {
              return true
            }
          }
        }
      }
    case *ast.DeferStmt:
      if findTransactionTypeByIdentName(s.Call, txIdentName, tranType) {
        return true
      }
      if n, ok := s.Call.Fun.(*ast.FuncLit); ok {
        if isTransactionFinished(fset, n.Body, 0, txIdentName, tranType) {
          return true
        }
      }
    }
  }
  return false
}

型スイッチが続いているのですが、これは tx.Commit() の書かれ方にいくつかのパターンがありそれを網羅するためです。

// if init で書かれている
if err := tx.Commit(); err != nil {
}

// if ブロックの中
if cond {
  tx.Commit()
}

// 直接リターンしている
return tx.Commit()

// エラーハンドリングなし (しろ)
tx.Commit()

...etc

カンム内ではだいたいこのような書き方をしているのでチーム内で主流なものはカバーします。ちなみに以下のように for...range の中に書くこともできるのですが、このような奇っ怪なものはサポートしないものとします。

  for _, err := range []error{err, tx.Commit()} {
    if err != nil {
      return err
    }
  }

isCommitImplemented() の処理中で登場する findTransactionTypeByIdentName() は次のようになっています。 Begin() や BeginTx() を探す findTransactionTypeBegin() と似ているのですが identName が一致しているかどうかを調べている点で違いがあります。

func findTransactionTypeByIdentName(node *ast.CallExpr, identName string, tranType TransactionType) bool {
  if fun, ok := node.Fun.(*ast.SelectorExpr); ok {
    if fun.Sel.Name == tranType.String() {
      if pkgIndent, ok := fun.X.(*ast.Ident); ok {
        if pkgIndent.Name == identName {
          return true
        }
      }
    }
  }
  if n, ok := node.Fun.(*ast.CallExpr); ok {
    return findTransactionTypeByIdentName(n, identName, tranType)
  }
  return false
}

これは Begin() は1関数内に1つだけとは限らず tx1, err := Begin() と tx2, err := Begin() が存在しているケースがありそれに対応するためです(tx1 の Commit() と tx2 の Commit() をそれぞれ探したい)。

isCommitImplemented() が関数を下まで走査し、tx.Commit() が見つからなかったら pass.Reportf() で Linter エラーを通知します。

                      committed := isCommitImplemented(pass.Fset, node.Body, beginPos, ident.Name)
                      if !committed {
                        pass.Reportf(s.Pos(), "transaction must be committed: %s", ident.Name)
                      }
                      rollbacked := isRollbackImplemented(pass.Fset, node.Body, beginPos, ident.Name)
                      if !rollbacked {
                        pass.Reportf(s.Pos(), "transaction must be rollbacked: %s", ident.Name)
                      }

Rollback() についても同様の処理を行います。以上にて Begin() や BeginTx() が書かれている関数内の処理について、それと対になる tx.Commit() と defer tx.Rollback() が存在しているかを調べる Linter ができあがりました。

まとめ

Go の静的解析のエコシステムに乗ることで比較的簡単に Linter が書けました。レビューで毎回見る箇所や毎回指摘している箇所、チーム内のコーディングルールのようなものがあったらそれを Linter 化してしまうのもありです。それによって書き手もレビュアーもよりビジネスロジックに集中して開発できるようになります。

おわり