ここでは Pandasのsample()
メソッド について、使い方から応用例まで詳しく解説します。
sample()
はデータフレームやシリーズからランダムにデータを抽出するための非常に便利なメソッドです。
目次
基本構文
# セクション0:共通セットアップ
import numpy as np
import pandas as pd
# 再現性のためのNumPyシード(pandasのrandom_stateとは別ですが、ダミーデータの再現に使用)
np.random.seed(0)
# ダミーデータの作成
N = 1000
df = pd.DataFrame({
"id": np.arange(1, N + 1),
"category": np.random.choice(["A", "B", "C"], size=N, p=[0.5, 0.3, 0.2]),
"value": np.random.normal(loc=50, scale=10, size=N)
})
# 重み(例:valueが大きいほど当たりやすくする。ただし非負で)
# 実運用では事前に正規化不要。pandas側で内部的に正規化してくれる。
w = (df["value"] - df["value"].min()) + 1e-6
df["weight"] = w
print(df.head(), "\n", df.shape)
主な引数の説明
n
と frac
n
: 抽出する行数を指定frac
: 全体に対する割合を指定(0〜1)
※n
とfrac
は同時に指定できません。
# 基本(nとfrac)
import pandas as pd
# 5行をランダム抽出(n)
sample_n = df.sample(n=5)
print("=== n=5 ===")
print(sample_n, "\n")
# 全体の10%をランダム抽出(frac)
sample_frac = df.sample(frac=0.10)
print("=== frac=0.10 ===")
print(sample_frac.shape) # 件数確認
replace
False
(デフォルト): 非復元抽出(同じ行は選ばれない)True
: 復元抽出(同じ行を何度も選べる)
# 復元抽出(replace=True)とブートストラップ
# データ数と同じ件数を復元抽出するとブートストラップサンプルになる
bootstrap = df.sample(n=len(df), replace=True, random_state=123)
dup_rate = 1 - bootstrap["id"].nunique() / len(bootstrap)
print("=== ブートストラップ ===")
print(bootstrap.head())
print(f"重複率の目安(同じ行が何度も選ばれている率): {dup_rate:.2%}")
weights
- 各行に対して抽出確率を設定できる。リストや列名を指定可能。
- 値は非負でなくてはならない。
- 例:
"prob"
列に格納された確率に基づいて抽出する。
# 重み付きサンプリング(weights)
# 1) 列名で指定
weighted_sample_col = df.sample(n=5, weights="weight", random_state=0)
print("=== weights='weight'(列名) ===")
print(weighted_sample_col[["id", "value", "weight"]], "\n")
# 2) 配列で指定(長さは行数と同じにする)
weights_array = np.where(df["category"].eq("A"), 5.0,
np.where(df["category"].eq("B"), 2.0, 1.0))
weighted_sample_arr = df.sample(n=5, weights=weights_array, random_state=0)
print("=== weights=配列(Aを特に当たりやすく) ===")
print(weighted_sample_arr[["id", "category"]])
random_state
- 乱数シードを固定して再現性を確保できる。
# 再現性(random_state)
s1 = df.sample(n=5, random_state=42)
s2 = df.sample(n=5, random_state=42) # 同じrandom_stateなら同じ結果
s3 = df.sample(n=5, random_state=7) # 乱数シードが異なれば異なる結果
print("=== random_state=42 サンプル1 ===")
print(s1[["id", "category", "value"]], "\n")
print("=== random_state=42 サンプル2(同一) ===")
print(s2[["id", "category", "value"]], "\n")
print("=== random_state=7 サンプル(異なる) ===")
print(s3[["id", "category", "value"]])
axis
axis=0
(デフォルト): 行を抽出axis=1
: 列を抽出
# 列方向のサンプリング(axis=1)
# 列をランダムに3つ抽出
cols_sampled = df.sample(n=3, axis=1, random_state=1)
print("=== 列サンプル(3列) ===")
print(cols_sampled.head())
ignore_index
True
にすると抽出後にインデックスを振り直す。
# インデックス振り直し(ignore_index)
s = df.sample(n=5, random_state=0, ignore_index=True)
print("=== ignore_index=True ===")
print(s.head())
print("Index連番:", s.index.tolist())
注意点
n
がデータ数より大きい場合、replace=False
ではエラーになります。weights
にNaNや負の値があるとエラー。- 大規模データの場合は抽出時に計算コストがかかるので、必要に応じて分割や分散処理を検討。
まとめ
sample()
は単なる「ランダム抽出」以上に、
- ランダムサンプリング
- ブートストラップ
- データ分割
- 重み付きサンプリング
など多用途に活用できるメソッドです。
以上、PythonのPandasのsampleメソッドについてでした。
最後までお読みいただき、ありがとうございました。