テーブルデータ生成AI CTGAN(その2)
主キーを指定する

テーブルデータ生成AI CTGAN(その2)主キーを指定する

分析で利用するテーブルデータが少ないことがあります。

もう少し増やせないだろうか、と夢見ることも少なくないでしょう。

ここ最近、色々な生成AIが登場してきました。

そこで使われている技術の1つにGAN(敵対的生成ネットワーク)というものがあります。あらかじめ準備されたデータをもとに、擬似的なデータを生成することができます。

テーブルデータを生成するGANもあります。CTGAN(Conditional General Adversarial Networks)です。

前回、Pythonを使いテーブルデータ生成AI CTGANで、簡単な例で使い方を説明しました。

テーブルデータ生成AI CTGAN(その1)取り急ぎデータ量を増やす

 

CTGAN(Conditional General Adversarial Networks)で、手軽にテーブルデータを生成することができますが、無邪気に生成すると、ちょっと変なデータが生成されることがあります。

例えば、データセットの主キーです。

主キーは、データベースのテーブルの行を一意に識別するために使用されるキーで、重複する値を持たず、常に一意である必要があります。

無邪気に生成すると、主キーが一意でないデータセットが生成される危険性があります。

今回は、主キーを指定しテーブルデータ生成AI CTGANで、テーブルデータを生成する方法を説明します。

PythonのSDVパッケージを使いますので、インストールされていないかたは、前回の記事を参考にインストールしてください。

利用するデータセット

今回利用する データセットは、大学の学生の就職状況に関するデータ(Campus Recruitment)です。

以下からもダウンロードできます。

Placement_Data_Full_Class.csv
https://www.salesanalytics.co.jp/su1r

このデータセットは、学生の学業成績、科目、仕事の経験、専門分野など、学生の就職に影響を与える要因や、学生の就職を予測するモデルを構築するために使用されたりします。

以下の、変数sl_noを主キーとするデータセットです。

  • sl_no: 学生番号 ※主キー
  • gender: 学生の性別
  • ssc_p: 義務教育10年目の試験の得点率
  • ssc_b: 義務教育10年目の試験を実施する委員会
  • hsc_p:義務教育12年目の試験の得点率
  • hsc_b: 義務教育12年目試験を実施する委員会
  • hsc_s:義務教育12年目の試験の専攻
  • degree_p: 工学学位の試験で得た得点の割合
  • degree_t: 工学学位の専攻
  • workex: 学生の職務経験の有無
  • etest_p: 入学試験における学生の得点率
  • specialisation: 学生の専門分野
  • mba_p: MBA試験における学生の得点率
  • status: 学生が就職したかどうか(Placed:内定した、Not Placed:内定しなかった
  • salary: 学生に提示された給与

 

このデータセットには、次の2つの制約もしくはルールがあります。

  1. sl_no(学生番号)は、レコードごとに異なり重複しない
  2. salary(提示された給与)は、statusがPlaced(内定)のときのみデータがある

 

今回は、sl_no(学生番号)が重複しないように、テーブルデータ生成することを考えます。もう一方は、次回扱います。

 

必要なモジュールの読み込み

必要なモジュールを読み込みます。

以下、コードです。

# 基本モジュール
import pandas as pd
import numpy as np

# SVD
from sdv.single_table import CTGANSynthesizer
from sdv.metadata import SingleTableMetadata

import warnings
warnings.simplefilter('ignore')

 

データセット読み込み

データセットを読み込みます。

以下、コードです。

dataset = 'Placement_Data_Full_Class.csv'
real_data = pd.read_csv(dataset)

print(real_data)

 

以下、実行結果です。

     sl_no gender  ssc_p    ssc_b  hsc_p    hsc_b     hsc_s  degree_p  \
0        1      M  67.00   Others  91.00   Others  Commerce     58.00   
1        2      M  79.33  Central  78.33   Others   Science     77.48   
2        3      M  65.00  Central  68.00  Central      Arts     64.00   
3        4      M  56.00  Central  52.00  Central   Science     52.00   
4        5      M  85.80  Central  73.60  Central  Commerce     73.30   
..     ...    ...    ...      ...    ...      ...       ...       ...   
210    211      M  80.60   Others  82.00   Others  Commerce     77.60   
211    212      M  58.00   Others  60.00   Others   Science     72.00   
212    213      M  67.00   Others  67.00   Others  Commerce     73.00   
213    214      F  74.00   Others  66.00   Others  Commerce     58.00   
214    215      M  62.00  Central  58.00   Others   Science     53.00   

      degree_t workex  etest_p specialisation  mba_p      status    salary  
0     Sci&Tech     No     55.0         Mkt&HR  58.80      Placed  270000.0  
1     Sci&Tech    Yes     86.5        Mkt&Fin  66.28      Placed  200000.0  
2    Comm&Mgmt     No     75.0        Mkt&Fin  57.80      Placed  250000.0  
3     Sci&Tech     No     66.0         Mkt&HR  59.43  Not Placed       NaN  
4    Comm&Mgmt     No     96.8        Mkt&Fin  55.50      Placed  425000.0  
..         ...    ...      ...            ...    ...         ...       ...  
210  Comm&Mgmt     No     91.0        Mkt&Fin  74.49      Placed  400000.0  
211   Sci&Tech     No     74.0        Mkt&Fin  53.62      Placed  275000.0  
212  Comm&Mgmt    Yes     59.0        Mkt&Fin  69.72      Placed  295000.0  
213  Comm&Mgmt     No     70.0         Mkt&HR  60.23      Placed  204000.0  
214  Comm&Mgmt     No     89.0         Mkt&HR  60.22  Not Placed       NaN  

[215 rows x 15 columns]

 
レコード数は215で、変数の数は15です。

 

データセットの特徴把握

読み込んだデータセットの情報を見てみます。

以下、コードです。

real_data.info()

 

以下、実行結果です。

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 215 entries, 0 to 214
Data columns (total 15 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   sl_no           215 non-null    int64  
 1   gender          215 non-null    object 
 2   ssc_p           215 non-null    float64
 3   ssc_b           215 non-null    object 
 4   hsc_p           215 non-null    float64
 5   hsc_b           215 non-null    object 
 6   hsc_s           215 non-null    object 
 7   degree_p        215 non-null    float64
 8   degree_t        215 non-null    object 
 9   workex          215 non-null    object 
 10  etest_p         215 non-null    float64
 11  specialisation  215 non-null    object 
 12  mba_p           215 non-null    float64
 13  status          215 non-null    object 
 14  salary          148 non-null    float64
dtypes: float64(6), int64(1), object(8)
memory usage: 25.3+ KB

 

salaryを見ていただくと分かりますが、N-n-Null Count(欠測していないデータの数)が148と、215よりも少なく欠測値がある(67名の方が内定をもらっていない)ことが分かります。

 

sl_noが重複していいないかどうかを確認します。

以下、コードです。

real_data['sl_no'].duplicated()

 

以下、実行結果です。重複している行がTrueになっています。

0      False
1      False
2      False
3      False
4      False
       ...  
210    False
211    False
212    False
213    False
214    False
Name: sl_no, Length: 215, dtype: bool

 

Trueの数を数えます。

以下、コードです。

real_data['sl_no'].duplicated().sum()

 

以下、実行結果です。

0

 

0です。

当然ですが、学習で利用するこのデータセットは、sl_noが重複していません。

 

メタデータの取得

テーブルデータを生成するために、先程読み込んだデータセットからメタデータを取得します。

以下、コードです。

# データフレームからメタデータを自動抽出
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(real_data)

 

念のため、メタデータを見てみます。

以下、コードです。

metadata

 

以下、実行結果です。各変数の型が定義されています。

{
    "METADATA_SPEC_VERSION": "SINGLE_TABLE_V1",
    "columns": {
        "sl_no": {
            "sdtype": "numerical"
        },
        "gender": {
            "sdtype": "categorical"
        },
        "ssc_p": {
            "sdtype": "numerical"
        },
        "ssc_b": {
            "sdtype": "categorical"
        },
        "hsc_p": {
            "sdtype": "numerical"
        },
        "hsc_b": {
            "sdtype": "categorical"
        },
        "hsc_s": {
            "sdtype": "categorical"
        },
        "degree_p": {
            "sdtype": "numerical"
        },
        "degree_t": {
            "sdtype": "categorical"
        },
        "workex": {
            "sdtype": "categorical"
        },
        "etest_p": {
            "sdtype": "numerical"
        },
        "specialisation": {
            "sdtype": "categorical"
        },
        "mba_p": {
            "sdtype": "numerical"
        },
        "status": {
            "sdtype": "categorical"
        },
        "salary": {
            "sdtype": "numerical"
        }
    }
}

 

データ生成その1(主Key設定なし)

主キーは、先程取得したメタデータを修正し指定します。

先ずは、メタデータを修正することなく、テーブルデータを生成していきます。

以下、コードです。

# インスタンス生成
ctgan = CTGANSynthesizer(metadata,epochs=10) 

# 学習
ctgan.fit(real_data) 

# データ生成
synthetic_data = ctgan.sample(20000) 

# 生成したデータを確認
print(synthetic_data)

 

以下、実行結果です。

       sl_no gender  ssc_p    ssc_b  hsc_p    hsc_b     hsc_s  degree_p  \
0        215      M  86.91   Others  67.75  Central  Commerce     50.00   
1        168      M  62.49   Others  94.87  Central   Science     51.76   
2        215      F  89.40   Others  61.27  Central  Commerce     59.31   
3        155      F  86.97   Others  96.04   Others   Science     78.99   
4        151      M  85.66   Others  75.40   Others  Commerce     65.19   
...      ...    ...    ...      ...    ...      ...       ...       ...   
19995    164      F  88.70  Central  63.86   Others  Commerce     54.74   
19996    215      M  78.41  Central  63.03  Central  Commerce     63.29   
19997    198      M  76.26  Central  92.06  Central   Science     50.00   
19998     13      M  69.93  Central  86.26   Others   Science     50.00   
19999    215      M  58.20   Others  81.18   Others  Commerce     55.25   

        degree_t workex  etest_p specialisation  mba_p      status    salary  
0      Comm&Mgmt    Yes    75.68        Mkt&Fin  56.84  Not Placed       NaN  
1      Comm&Mgmt     No    70.20         Mkt&HR  59.94  Not Placed  200000.0  
2         Others     No    73.44        Mkt&Fin  52.34      Placed       NaN  
3         Others     No    74.43        Mkt&Fin  55.45      Placed       NaN  
4       Sci&Tech    Yes    64.69        Mkt&Fin  52.47  Not Placed       NaN  
...          ...    ...      ...            ...    ...         ...       ...  
19995  Comm&Mgmt    Yes    93.29         Mkt&HR  54.85      Placed  306046.0  
19996  Comm&Mgmt    Yes    55.07        Mkt&Fin  68.71  Not Placed  317092.0  
19997  Comm&Mgmt     No    98.00         Mkt&HR  51.21  Not Placed  485863.0  
19998   Sci&Tech    Yes    64.21         Mkt&HR  51.21      Placed  268246.0  
19999  Comm&Mgmt     No    51.46         Mkt&HR  57.42      Placed  399696.0  

[20000 rows x 15 columns]

 

生成したテーブルデータのsl_noが重複していいないかどうかを確認します。

以下、コードです。

synthetic_data['sl_no'].duplicated().sum()

 

以下、実行結果です。

19785

 

19785です。かなりsl_noが重複しています。

 

メタデータの修正(主キー設定)

メタデータを修正し主キーを設定します。

主キーに設定する変数は、例えば「id」型である必要があるため、先ず型変換を行い、次に主キー設定をします。

以下、コードです。

# 変数の型の変更
metadata.update_column(
    column_name='sl_no',
    sdtype='id')

# 主Key設定
metadata.set_primary_key(column_name="sl_no")

 

メターデータを確認してみます。

以下、コードです。

metadata

 

以下、実行結果です。

{
    "primary_key": "sl_no",
    "METADATA_SPEC_VERSION": "SINGLE_TABLE_V1",
    "columns": {
        "sl_no": {
            "sdtype": "id"
        },
        "gender": {
            "sdtype": "categorical"
        },
        "ssc_p": {
            "sdtype": "numerical"
        },
        "ssc_b": {
            "sdtype": "categorical"
        },
        "hsc_p": {
            "sdtype": "numerical"
        },
        "hsc_b": {
            "sdtype": "categorical"
        },
        "hsc_s": {
            "sdtype": "categorical"
        },
        "degree_p": {
            "sdtype": "numerical"
        },
        "degree_t": {
            "sdtype": "categorical"
        },
        "workex": {
            "sdtype": "categorical"
        },
        "etest_p": {
            "sdtype": "numerical"
        },
        "specialisation": {
            "sdtype": "categorical"
        },
        "mba_p": {
            "sdtype": "numerical"
        },
        "status": {
            "sdtype": "categorical"
        },
        "salary": {
            "sdtype": "numerical"
        }
    }
}

 

primary_keyが主キーを表します。sl_noになっているのが分かるかと思います。

sl_no型(sdtype)がidになっているのも分かるかと思います。

この状態で、テーブルデータを生成していきます。

 

データ生成その2(主キー設定あり)

sl_no主キーに設定したメタデータを使い、テーブルデータを生成します。

以下、コードです。

# インスタンス生成
ctgan = CTGANSynthesizer(metadata,epochs=10) 

# 学習
ctgan.fit(real_data) 

# データ生成
synthetic_data = ctgan.sample(20000) 

# 生成したデータを確認
print(synthetic_data)

 

以下、実行結果です。

       sl_no gender  ssc_p    ssc_b  hsc_p    hsc_b     hsc_s  degree_p  \
0          0      M  63.32  Central  51.18   Others      Arts     78.93   
1          1      F  65.35   Others  66.10  Central   Science     67.23   
2          2      F  40.89  Central  47.76  Central   Science     67.20   
3          3      F  40.89   Others  60.37   Others  Commerce     89.05   
4          4      F  56.19  Central  39.60  Central  Commerce     63.17   
...      ...    ...    ...      ...    ...      ...       ...       ...   
19995  19995      M  64.25  Central  43.54   Others      Arts     77.72   
19996  19996      F  76.01  Central  70.35   Others  Commerce     71.96   
19997  19997      F  72.59  Central  97.70   Others      Arts     67.86   
19998  19998      M  51.43   Others  56.94  Central   Science     59.97   
19999  19999      M  56.75  Central  52.63  Central  Commerce     50.00   

        degree_t workex  etest_p specialisation  mba_p      status    salary  
0       Sci&Tech    Yes    65.37        Mkt&Fin  65.79      Placed  282434.0  
1       Sci&Tech    Yes    84.75         Mkt&HR  57.03      Placed  374237.0  
2       Sci&Tech     No    58.12        Mkt&Fin  51.21      Placed       NaN  
3      Comm&Mgmt    Yes    82.79        Mkt&Fin  52.39      Placed  200000.0  
4       Sci&Tech     No    71.98         Mkt&HR  53.52  Not Placed       NaN  
...          ...    ...      ...            ...    ...         ...       ...  
19995   Sci&Tech    Yes    92.09        Mkt&Fin  53.48      Placed       NaN  
19996  Comm&Mgmt    Yes    82.98         Mkt&HR  55.71  Not Placed       NaN  
19997   Sci&Tech    Yes    59.22         Mkt&HR  62.57      Placed  457765.0  
19998   Sci&Tech    Yes    75.93         Mkt&HR  51.21      Placed  293694.0  
19999     Others     No    50.00        Mkt&Fin  51.21      Placed  394105.0 

[20000 rows x 15 columns]

 

生成したテーブルデータのsl_noが重複していいないかどうかを確認します。

以下、コードです。

synthetic_data['sl_no'].duplicated().sum()

 

以下、実行結果です。

0

 

0です。sl_noが重複していません。

 

ここで、もう1つの制約というかルールであった「salary(提示された給与)は、statusがPlaced(内定)のときのみデータがある」が保持されているかどうか見てみます。

先ず、statussalaryの変数だけ抜き出して見てみます。

以下、コードです。

print(synthetic_data.loc[:,['status','salary']])

 

以下、実行結果です。

           status    salary
0          Placed  282434.0
1          Placed  374237.0
2          Placed       NaN
3          Placed  200000.0
4      Not Placed       NaN
...           ...       ...
19995      Placed       NaN
19996  Not Placed       NaN
19997      Placed  457765.0
19998      Placed  293694.0
19999      Placed  394105.0

[20000 rows x 2 columns]

 

status変数Placed(内定)なのにsalary変数が欠測しているなど、上手く行っていない様子が分かります。

 

Placed(内定)かNot Placedどうかで、salary変数基本統計量がどうなっているか見てみます。

以下、コードです。

print(synthetic_data.loc[:,['status','salary']].groupby('status').describe())

 

以下、実行結果です。

            salary                                                     \
             count           mean            std       min        25%   
status                                                                  
Not Placed  6134.0  397787.475383  158847.556333  200000.0  296821.75   
Placed      7679.0  399903.842427  158153.818051  200000.0  297113.00   

                                          
                 50%       75%       max  
status                                    
Not Placed  363193.0  452395.5  940000.0  
Placed      368274.0  457405.0  940000.0

 

Not Placed(内定がでていない)なのにsalary変数に値のある(給与が提示された)レコードがあることが分かります。

要するに、「salary(提示された給与)は、statusがPlaced(内定)のときのみデータがある」が保持されていません。

 

まとめ

今回は、主キーを指定しテーブルデータ生成AI CTGANで、テーブルデータを生成する方法を説明しました。

メタデータの設定を変えることで簡単に対応できます。

テーブルデータには、主キーにも色々な制約やルールが課されていることがあります。

次回は、今回と同じデータセットを使い、変数間の関係性に関する制約・ルールを定義し、テーブルデータを生成する方法について説明します。

要は、「salary(提示された給与)は、statusがPlaced(内定)のときのみデータがある」を制約・ルールを課すということです。

テーブルデータ生成AI CTGAN(その3)カスタム制約を与える