Skip to content

Polars is hard when working with Decimal

def fetch_ohlcv(session: Session, ticker: str) -> pl.DataFrame:
    """Fetch the de-duplicated end-of-day OHLCV rows of a single ticker.

    The query reads ``csh_ohlcv_eod`` and keeps only rows that:

    * have ``status = 'A'``,
    * are dated on or before the current date in the ``+7:00`` time zone,
    * have strictly positive ``open``, ``high``, ``low``, ``close`` and
      ``volume``.

    When several rows share the same ``(ticker, date)``, only the one with the
    highest volume is kept; ties are broken by the most recent of
    ``deleted_at`` / ``updated_at`` / ``created_at``.

    Args:
        session: SQLAlchemy session used to execute the query.
        ticker: Ticker symbol bound to the ``:tick`` query parameter.

    Returns:
        A polars DataFrame with the columns ``ticker``, ``date``, ``open``,
        ``high``, ``low``, ``close``, ``volume`` and ``last_updated_at``,
        typed according to ``output_schema``. The frame is empty (but still
        carries that schema) when the query yields no rows.

    Raises:
        Exception: If executing the query fails; the original error is
            chained as ``__cause__``.
    """
    output_schema: dict[str, pl.DataType] = {
        "ticker": pl.Utf8,
        "date": pl.Date,
        "open": pl.Float64,
        "high": pl.Float64,
        "low": pl.Float64,
        "close": pl.Float64,
        "volume": pl.Int64,
        "last_updated_at": pl.Datetime,
    }

    try:
        res = session.execute(
            sa.text(r"""
            WITH element AS (
                SELECT
                    ticker, `date`, `open`, high, low, `close`, volume,
                    COALESCE(deleted_at, updated_at, created_at) AS last_updated_at,
                    ROW_NUMBER() OVER (PARTITION BY `ticker`, `date` ORDER BY volume DESC, COALESCE(deleted_at, updated_at, created_at) DESC) AS numb
                FROM
                    csh_ohlcv_eod
                WHERE
                    `status` = 'A'
                    AND `date` <= CAST(CONVERT_TZ(CURRENT_TIMESTAMP, 'GMT', '+7:00') AS DATE)
                    AND (`open` > 0 AND  `high` > 0 AND `low` > 0 AND `close` > 0 AND volume > 0)
                    AND ticker = :tick
            )
            SELECT ticker, `date`, `open`, high, low, `close`, volume, last_updated_at
            FROM element
            WHERE numb = 1
            """),
            execution_options={
                "schema_translate_map": {None: sa.quoted_name(__CONFIG__.CLOUD_SQL_DATABASE_NAME, quote=True)}
            },
            params={
                "tick": ticker
            }
        )
    except Exception as exc:
        raise Exception(f"Fetch the OHLCV of security {ticker}") from exc

    rows = [row._asdict() for row in res.all()]
    if not rows:
        # Nothing matched: return an empty frame that still carries the schema.
        return pl.DataFrame(data=[], schema=output_schema, orient="row")

    # NOTE: per the original author's observation, building the polars frame
    # directly from these rows with an explicit schema fails because the
    # driver returns `Decimal` values for the price columns. Going through
    # pandas first and casting afterwards sidesteps that.
    #
    # `pl.from_pandas` is the public constructor; the private
    # `pl.DataFrame._from_pandas` is not guaranteed to exist across polars
    # versions.
    frame = pl.from_pandas(pd.DataFrame(data=rows))

    # Cast every column in a single pass; `select` also pins the column set
    # and order to `output_schema`, matching the empty-result frame.
    return frame.select(
        [pl.col(name).cast(dtype) for name, dtype in output_schema.items()]
    )