テンサーの精度検証をレベルアップ!PyTorchのtorch.testing.assert_close()関数と代替方法

2024-06-29

PyTorchにおける torch.testing.assert_close() の詳細解説

動作原理

torch.testing.assert_close() は、以下の式に基づいて2つのテンサー actualexpected を比較します。

|actual - expected| <= atol + rtol * |expected|

ここで、

  • atol: 絶対許容誤差を表します。これは、expected の値に関係なく、actualexpected の間の許容される最大差を表します。
  • rtol: 相対許容誤差を表します。これは、expected の値の一定割合を許容される最大差として設定します。

2つのテンサーが上記の式を満たす場合、assert_close() は成功し、テストは合格となります。一方、式を満たさない場合、assert_close() は失敗し、テストは不合格となります。

主な引数

torch.testing.assert_close() は、以下の引数を受け取ります。

  • actual: 比較対象となる最初のテンサー
  • expected: 比較対象となる2番目のテンサー
  • atol: 絶対許容誤差(デフォルト: 1e-08
  • rtol: 相対許容誤差(デフォルト: 1e-05
  • equal_nan: Trueの場合、NaN値も等しいとみなされます(デフォルト: False)
  • check_dtype: Trueの場合、2つのテンサーのデータ型が一致していることを確認します(デフォルト: True)
  • check_device: Trueの場合、2つのテンサーが同じデバイス上にあることを確認します(デフォルト: True)

使用例

以下の例は、torch.testing.assert_close() を使用して2つのテンサーを比較する方法を示しています。

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.0001, 2.0002, 3.0003])

# 許容誤差を設定
atol = 1e-3
rtol = 1e-4

# テストを実行
tt.assert_close(actual, expected, atol=atol, rtol=rtol)

上記の例では、atolrtol をそれぞれ 1e-31e-4 に設定しています。これは、expected の値の0.1%未満の差であれば許容されることを意味します。

  • torch.testing.assert_close() は、torch.allclose() と似ていますが、いくつかの重要な違いがあります。
    • torch.allclose() は、テンサー内のすべての要素が等しいかどうかを検証します。一方、torch.testing.assert_close() は、許容範囲内で近似的に等しいかどうかを検証します。
    • torch.allclose() は、NaN値を考慮しません。一方、torch.testing.assert_close() は、equal_nan オプションを指定することでNaN値を考慮することができます。
  • torch.testing.assert_close() は、PyTorchのテストにおいて、テンサーの精度を検証するための強力なツールです。許容誤差を適切に設定することで、モデルの訓練や推論の過程で生成された結果の信頼性を確認することができます。


基本的な例

この例では、2つのテンサーを絶対許容誤差と相対許容誤差を使用して比較します。

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.0001, 2.0002, 3.0003])

# 許容誤差を設定
atol = 1e-3
rtol = 1e-4

# テストを実行
tt.assert_close(actual, expected, atol=atol, rtol=rtol)

絶対許容誤差のみを使用する例

この例では、絶対許容誤差のみを使用して2つのテンサーを比較します。

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.001, 2.002, 3.003])

# 許容誤差を設定
atol = 1e-2

# テストを実行
tt.assert_close(actual, expected, atol=atol)

相対許容誤差のみを使用する例

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([100.0, 200.0, 300.0])
expected = torch.tensor([100.1, 200.2, 300.3])

# 許容誤差を設定
rtol = 1e-3

# テストを実行
tt.assert_close(actual, expected, rtol=rtol)

NaN値を考慮する例

この例では、equal_nan オプションを使用して、NaN値を考慮したテンサーの比較を行います。

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([1.0, 2.0, float('nan')])
expected = torch.tensor([1.0001, 2.0002, float('nan')])

# テストを実行
tt.assert_close(actual, expected, equal_nan=True)

データ型のチェックを無効にする例

この例では、check_dtype オプションを使用して、データ型のチェックを無効にします。

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
expected = torch.tensor([1.0001, 2.0002, 3.0003], dtype=torch.float64)

# テストを実行
tt.assert_close(actual, expected, check_dtype=False)

デバイスのチェックを無効にする例

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([1.0, 2.0, 3.0], device='cpu')
expected = torch.tensor([1.0001, 2.0002, 3.0003], device='cuda')

# テストを実行
tt.assert_close(actual, expected, check_device=False)

これらの例は、torch.testing.assert_close() の様々な使用方法を示しています。具体的な状況に合わせて、適切なオプションを設定して使用してください。

  • torch.testing.assert_close() は、PyTorchのテストにおいて、テンサーの精度を検証するための強力なツールです。許容誤差を適切に設定することで、モデルの訓練や推論の過程で生成された結果の信頼性を確認することができます。


torch.testing.assert_close() の代替方法

torch.allclose() は、torch.testing.assert_close() と似ていますが、以下の点が異なります。

  • 絶対許容誤差と相対許容誤差の代わりに、許容誤差の閾値のみを指定します。
  • デフォルトの許容誤差は atol=1e-09rtol=1e-5 です。
  • NaN値を考慮しません。

以下の例は、torch.allclose() を使用して2つのテンサーを比較する方法を示しています。

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([1.0001, 2.0002, 3.0003])
expected = torch.tensor([1.0, 2.0, 3.0])

# 許容誤差を設定
rtol = 1e-3

# テストを実行
tt.allclose(actual, expected, rtol=rtol)

カスタム断言関数:

許容誤差の計算ロジックや、NaN値の扱いなどを細かく制御したい場合は、カスタム断言関数を作成することもできます。

以下の例は、カスタム断言関数を使用して2つのテンサーを比較する方法を示しています。

import torch

def my_assert_close(actual, expected, atol=1e-3, rtol=1e-5, equal_nan=False):
    # 絶対許容誤差と相対許容誤差に基づいて差を計算
    diff = actual - expected
    abs_diff = torch.abs(diff)
    rel_diff = abs_diff / torch.abs(expected)

    # 許容範囲内に収まっているかどうかを確認
    if (diff <= atol) or (rel_diff <= rtol):
        return True
    else:
        if equal_nan and torch.isnan(diff).any():
            return True
        else:
            return False

# テスト対象のテンサーを作成
actual = torch.tensor([1.0001, 2.0002, 3.0003])
expected = torch.tensor([1.0, 2.0, 3.0])

# テストを実行
my_assert_close(actual, expected)

NumPyやSciPyなどの他のライブラリも、テンサーの比較に使用することができます。

  • NumPy: np.allclose()
  • SciPy: scipy.allclose()

これらのライブラリの関数は、PyTorchの torch.testing.assert_close() とは異なるオプションや機能を提供している場合があります。

選択の指針:

  • シンプルで使いやすい: torch.testing.assert_close()
  • 許容誤差の計算ロジックを制御したい: カスタム断言関数
  • 特定のライブラリとの整合性を保ちたい: NumPyやSciPy

注意事項:

  • torch.allclose()torch.testing.assert_close() と完全に互換性があるわけではありません。
  • カスタム断言関数を作成する場合は、テストのロジックが明確で分かりやすいように記述することが重要です。
  • 他のライブラリの関数は、PyTorchのテンサーとは異なる形式でデータを受け入れる場合があるため、注意が必要です。

これらの代替方法を理解することで、状況に応じて適切な方法を選択し、テストコードをより柔軟かつ効率的に記述することができます。