Major changes for real this time
This commit is contained in:
113
src/dataset.py
113
src/dataset.py
@@ -3,48 +3,113 @@ import pandas as pd
|
||||
import openml
|
||||
|
||||
# --- CACHE SETUP ---
# Point OpenML's download cache at a stable local directory so repeated runs
# reuse previously fetched datasets instead of re-downloading them.
# Change this path to your preferred local cache directory.
CACHE_DIR = os.path.expanduser("~/openml_cache")
os.makedirs(CACHE_DIR, exist_ok=True)  # idempotent: safe if the dir already exists
openml.config.cache_directory = CACHE_DIR
|
||||
# --- Dataset IDs ---
# OpenML CC18 classification tasks (task ids).
# NOTE: the previous revision listed each key twice (diff residue); Python
# silently keeps only the last duplicate, so the literal is consolidated here.
TASKS = {
    "adult": 7592,      # Adult Income classification
    "spambase": 43,     # Spambase classification
    "optdigits": 28,    # Optdigits classification
}
|
||||
|
||||
# Regression dataset (OpenML dataset id).
DATASETS = {"cal_housing": 44025}
|
||||
|
||||
# --- Load functions ---
def _load_task_dataframe(task_id: int):
    """Fetch the dataset behind an OpenML task and return ``(X, y)``.

    Parameters
    ----------
    task_id : int
        OpenML task identifier (see ``TASKS``).

    Returns
    -------
    tuple
        ``(X, y)`` — feature DataFrame and target Series, with any rows
        whose target value is missing dropped from both.
    """
    task = openml.tasks.get_task(task_id)
    dataset = openml.datasets.get_dataset(task.dataset_id)
    # target= splits the target column out of X for us.
    X, y, _, _ = dataset.get_data(dataset_format="dataframe", target=task.target_name)
    # Drop rows with a missing target (mask aligns X and y on the same index).
    mask = ~y.isna()
    return X.loc[mask], y.loc[mask]
|
||||
|
||||
def load_dataset(name: str):
    """Load a named dataset and return ``(X, y, task_type)``.

    Parameters
    ----------
    name : str
        Key in ``TASKS`` (classification) or ``DATASETS`` (regression).

    Returns
    -------
    tuple
        ``(X, y, task_type)`` where ``task_type`` is ``"classification"``
        or ``"regression"`` and rows with a missing target are dropped.

    Raises
    ------
    ValueError
        If *name* is in neither ``TASKS`` nor ``DATASETS``.
    """
    if name in TASKS:
        X, y = _load_task_dataframe(TASKS[name])
        task_type = "classification"
    elif name in DATASETS:
        ds = openml.datasets.get_dataset(DATASETS[name])
        target_col = ds.default_target_attribute
        # target=None keeps the target column inside X; we split it off below.
        X, _, _, _ = ds.get_data(dataset_format="dataframe", target=None)
        # BUG FIX: apply the NA mask exactly once. The previous code did
        # ``X = X.loc[mask]`` and then ``y = X[target_col].loc[mask]`` — the
        # second .loc re-applies a full-length boolean mask to the already
        # filtered frame, which pandas rejects as an unalignable indexer.
        mask = ~X[target_col].isna()
        X = X.loc[mask]
        y = X[target_col]
        task_type = "regression"
    else:
        raise ValueError(f"Unknown dataset {name}")
    return X, y, task_type
|
||||
|
||||
# --- Interactive main ---
def _prompt_dataset():
    """Print the dataset catalogue and return the (validated) user choice."""
    print("Available datasets:")
    all_datasets = list(TASKS.keys()) + list(DATASETS.keys())
    for i, name in enumerate(all_datasets):
        print(f"{i+1}. {name}")

    selection = input("Enter the dataset name: ").strip()
    if selection not in all_datasets:
        raise ValueError(f"Dataset '{selection}' not recognized.")
    return selection


def _choose_target(X, y, default_target):
    """Ask for a target column; empty input keeps the default. Returns (X, y, target)."""
    target = input("\nEnter the target feature (or press Enter to use default): ").strip()
    if target:
        if target not in X.columns:
            raise ValueError(f"Target feature '{target}' not found.")
        y = X[target]
    else:
        target = default_target
    # The chosen target may or may not still be a column of X, hence errors="ignore".
    X = X.drop(columns=[target], errors="ignore")
    return X, y, target


def _exclude_features(X):
    """Interactively drop user-specified feature columns; returns the trimmed X."""
    exclude_input = input("\nEnter features to exclude (comma-separated), or press Enter to skip: ").strip()
    if not exclude_input:
        return X
    for col in (c.strip() for c in exclude_input.split(",")):
        if col in X.columns:
            X = X.drop(columns=[col])
        else:
            print(f"Warning: '{col}' not found in dataset and cannot be excluded.")
    return X


def _export_csv(X, y, target):
    """Optionally save X plus the target to a CSV and record its absolute path."""
    output_file = input("\nEnter filename to save dataset as CSV (e.g., dataset.csv): ").strip()
    if not output_file:
        return
    df_export = X.copy()
    df_export[target] = y  # append target at the end
    df_export.to_csv(output_file, index=False)
    print(f"Dataset saved to {output_file} (target column: '{target}')")

    # Save the CSV path to a temporary text file in the current directory
    # so downstream tooling can pick up the last exported file.
    temp_path_file = "last_csv_path.txt"
    with open(temp_path_file, "w") as f:
        f.write(os.path.abspath(output_file))
    print(f"CSV path written to {temp_path_file}")


def main():
    """Interactive CLI: pick a dataset, choose a target, trim features, export CSV."""
    selection = _prompt_dataset()
    X, y, task_type = load_dataset(selection)

    # --- Identify default target ---
    default_target = y.name
    print(f"\nDefault target column: {default_target}")

    # --- Print all features (without explanations) ---
    print("\nFeatures in the dataset:")
    for col in X.columns.unique():
        print(f"- {col} ({X[col].dtype})")

    X, y, target = _choose_target(X, y, default_target)
    X = _exclude_features(X)

    # --- Show preview ---
    print("\nFinal dataset preview (first 5 rows):")
    print(X.head())
    print("\nTarget preview (first 5 rows):")
    print(y.head())
    print(f"\nTask type: {task_type}")
    print(f"Target column: {target}")
    print(f"Number of features: {len(X.columns)}")

    _export_csv(X, y, target)
|
||||
# Run the interactive selector only when executed as a script,
# not when this module is imported for its load functions.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user