from dataclasses import dataclass, field
from datetime import datetime
from typing import Callable, List, Optional, Union

import pandas as pd

from cowidev.utils.s3 import obj_from_s3
@dataclass
class Grapheriser:
    """Reshape a dataset into a ``Country`` / ``Year`` / data-columns table.

    The transformation is a fixed sequence of steps (see :meth:`pipeline`):
    optional pivot, creation of the metadata columns, column-name flattening,
    column ordering/sorting and NaN filling.

    Attributes:
        location: Name of the input column holding the location; used as
            pivot index and renamed to ``Country``.
        date: Name of the input column holding the date; parsed as a date by
            :meth:`read`, used as pivot index and converted into ``Year``.
        date_ref: Reference date; ``Year`` is the number of days since it.
        fillna: If True, forward-fill data columns within each ``Country``.
        fillna_0: If True, replace remaining NaNs in data columns with 0
            (except for the columns in ``columns_non_fillna_0``).
        pivot_column: Column whose values become new columns. Pivoting only
            happens when both this and ``pivot_values`` are not None.
        pivot_values: Column name (or list of names) holding the values to
            pivot.
        suffixes: Suffix (or list of suffixes, one per entry of
            ``pivot_values``) appended to the pivoted column names. ``None``
            entries are treated as an empty suffix.
        function_input: Hook applied to the dataframe before any other step.
        function_output: Hook applied to the dataframe after all other steps.
        columns_non_fillna_0: Data columns excluded from the fill-with-0 step.
    """

    location: str = "location"
    date: str = "date"
    date_ref: datetime = datetime(2020, 1, 21)
    fillna: bool = False
    fillna_0: bool = True
    pivot_column: Optional[str] = None
    pivot_values: Optional[Union[str, List[str]]] = None
    suffixes: Optional[Union[str, List[Optional[str]]]] = None
    function_input: Callable = lambda x: x
    function_output: Callable = lambda x: x
    columns_non_fillna_0: list = field(default_factory=list)

    @property
    def columns_metadata(self) -> list:
        """Names of the metadata columns, in output order."""
        return ["Country", "Year"]

    @property
    def metric2suffix(self) -> dict:
        """Mapping from each pivot value column to its suffix.

        Raises:
            ValueError: If ``pivot_values`` and ``suffixes`` (both coerced to
                lists) do not have the same length.
        """
        values = self.pivot_values_list
        suffixes = self.suffixes_list
        if len(values) != len(suffixes):
            raise ValueError("`suffixes` and `pivot_values` should be lists of the same length")
        return dict(zip(values, suffixes))

    @property
    def pivot_values_list(self) -> list:
        """``pivot_values`` as a list (a non-list value is wrapped)."""
        if isinstance(self.pivot_values, list):
            return self.pivot_values
        return [self.pivot_values]

    @property
    def suffixes_list(self) -> list:
        """``suffixes`` as a list of strings (``None`` entries become ``""``)."""
        suffixes = self.suffixes if isinstance(self.suffixes, list) else [self.suffixes]
        return ["" if s is None else s for s in suffixes]

    @property
    def do_pivot(self) -> bool:
        """Whether the pivot step is enabled."""
        return self.pivot_column is not None and self.pivot_values is not None

    def columns_data(self, df: pd.DataFrame) -> list:
        """Return the columns of ``df`` that are not metadata columns."""
        return [col for col in df.columns if col not in self.columns_metadata]

    def pipe_pivot(self, df: pd.DataFrame) -> pd.DataFrame:
        """Pivot values of columns of interest.

        Returns ``df`` unchanged when pivoting is disabled. Otherwise the
        result has two-level column labels (value column, pivot category).
        """
        if not self.do_pivot:
            return df
        return df.pivot(
            index=[self.location, self.date],
            columns=self.pivot_column,
            values=self.pivot_values_list,
        ).reset_index()

    def pipe_metadata_columns(self, df: pd.DataFrame) -> pd.DataFrame:
        """Create the ``Country`` and ``Year`` metadata columns.

        The ``location`` column is renamed to ``Country``; the ``date`` column
        is renamed to ``Year`` and replaced by the number of days elapsed
        since ``date_ref``. The input dataframe is not modified.

        NOTE(review): `pipeline` calls this step, but the class had no
        definition for it. This implementation is reconstructed from
        `columns_metadata`, `location`, `date` and `date_ref` -- please
        confirm it matches the intended behaviour.
        """
        # `rename` returns a new frame and also works on the two-level column
        # labels produced by the pivot step.
        df = df.rename(columns={self.location: "Country", self.date: "Year"})
        df["Year"] = (pd.to_datetime(df["Year"]) - self.date_ref).dt.days
        return df

    def pipe_normalize_columns(self, df: pd.DataFrame) -> pd.DataFrame:
        """Normalize column names.

        If columns are multiindex (of length 2), use first and second positions to create new column name.
        This only applies if pivot has been done, i.e. `pivot_column` and `pivot_values` are not None.

        Raises:
            ValueError: If a column label does not have length 2, or if
                ``suffixes`` and ``pivot_values`` have different lengths.
        """
        if not self.do_pivot:
            return df

        # Built once instead of once per column.
        metric2suffix = self.metric2suffix

        def _normalize_column(column):
            if len(column) != 2:
                raise ValueError("Column is expected to have length 2")
            metric, category = column
            if category:
                # Pivoted column: category name plus the suffix of its metric.
                return f"{category}{metric2suffix.get(metric, '')}"
            # Index column restored by `reset_index`: second level is empty.
            return metric

        new_columns = [_normalize_column(column) for column in df.columns]
        # `set_axis` returns a new object, leaving the caller's frame intact.
        return df.set_axis(new_columns, axis=1)

    def pipe_order_columns(self, df: pd.DataFrame) -> pd.DataFrame:
        """Re-order the columns of the dataframe and sort its rows.

        First columns are [Country, Year]; rows are sorted by all columns in
        that same order.
        """
        col_order = self.columns_metadata + self.columns_data(df)
        return df[col_order].sort_values(col_order)

    def pipe_fillna(self, df: pd.DataFrame) -> pd.DataFrame:
        """Fill missing values in the data columns.

        If ``fillna`` is set, values are forward-filled within each
        ``Country``. If ``fillna_0`` is set, remaining NaNs are replaced by 0
        in every data column not listed in ``columns_non_fillna_0``. The
        input dataframe is not modified.
        """
        columns_data = self.columns_data(df)
        df = df.copy()
        if self.fillna and columns_data:
            df[columns_data] = df.groupby("Country")[columns_data].ffill()
        if self.fillna_0:
            cols_fillna0 = [c for c in columns_data if c not in self.columns_non_fillna_0]
            if cols_fillna0:
                df[cols_fillna0] = df[cols_fillna0].fillna(0)
        return df

    def read(self, input_path: str):
        """Load the input data, parsing the ``date`` column as dates.

        Paths starting with ``s3://`` are delegated to ``obj_from_s3``;
        anything else is read with ``pandas.read_csv``.
        """
        if input_path.startswith("s3://"):
            return obj_from_s3(input_path, parse_dates=[self.date])
        return pd.read_csv(input_path, parse_dates=[self.date])

    def pipeline(self, df: pd.DataFrame) -> pd.DataFrame:
        """Apply all transformation steps, in order, and return the result."""
        return (
            df.pipe(self.function_input)
            .pipe(self.pipe_pivot)
            .pipe(self.pipe_metadata_columns)
            .pipe(self.pipe_normalize_columns)
            .pipe(self.pipe_order_columns)
            .pipe(self.pipe_fillna)
            .pipe(self.function_output)
        )

    def run(self, input_path: str, output_path: str):
        """Read ``input_path``, transform it and write CSV to ``output_path``."""
        df = self.read(input_path)
        df = self.pipeline(df)
        df.to_csv(output_path, index=False)