Source code for scystream.sdk.database_handling.postgres_manager

from pyspark.sql import SparkSession, DataFrame
from pydantic import BaseModel
from scystream.sdk.env.settings import PostgresSettings


class PostgresConfig(BaseModel):
    """
    Configuration class for PostgreSQL connection details.

    This class holds the necessary configuration parameters to connect to
    a PostgreSQL database. It includes the database user, password, host,
    and port.

    :param PG_USER: The username for the PostgreSQL database.
    :param PG_PASS: The password for the PostgreSQL database.
    :param PG_HOST: The host address of the PostgreSQL server.
    :param PG_PORT: The port number of the PostgreSQL server.
    """
    PG_USER: str
    PG_PASS: str
    PG_HOST: str
    PG_PORT: int
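
# A minimal usage sketch (not part of the module itself); the credential
# values below are placeholders, not real settings:
#
#     config = PostgresConfig(
#         PG_USER="spark_user",
#         PG_PASS="change-me",
#         PG_HOST="localhost",
#         PG_PORT=5432,
#     )
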
class PostgresOperations:
    """
    Class to perform PostgreSQL operations using Apache Spark.

    This class provides methods to read from and write to a PostgreSQL
    database using JDBC and Spark's DataFrame API. It requires a
    SparkSession and either a PostgresConfig object or the
    PostgresSettings of an input or output for database connectivity.
    """
    def __init__(
        self,
        spark: SparkSession,
        config: PostgresConfig | PostgresSettings
    ):
        self.spark_session = spark
        self.jdbc_url = \
            f"jdbc:postgresql://{config.PG_HOST}:{config.PG_PORT}"
        self.properties = {
            "user": config.PG_USER,
            "password": config.PG_PASS,
            "driver": "org.postgresql.Driver"
        }
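    # A hedged sketch of how the required SparkSession might be built.
    # The PostgreSQL JDBC driver must be on Spark's classpath; the Maven
    # coordinate and version below are illustrative assumptions, not
    # pinned by this module:
    #
    #     spark = SparkSession.builder \
    #         .appName("postgres-io") \
    #         .config("spark.jars.packages",
    #                 "org.postgresql:postgresql:42.7.3") \
    #         .getOrCreate()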
    def read(
        self,
        database_name: str,
        table: str | None = None,
        query: str | None = None
    ) -> DataFrame:
        """
        Reads data from a PostgreSQL database into a Spark DataFrame.

        This method can either read data from a specified table or execute
        a custom SQL query to retrieve data from the database.

        :param database_name: The name of the database to connect to.
        :param table: The name of the table to read data from. Must be
            provided if `query` is not supplied. (optional)
        :param query: A custom SQL query to run. If provided, this
            overrides the `table` parameter. (optional)
        :raises ValueError: If neither `table` nor `query` is provided.
        :return: A Spark DataFrame containing the result of the query or
            the table data.
        :rtype: DataFrame
        """
        if not table and not query:
            raise ValueError("Either 'table' or 'query' must be provided.")

        db_url = f"{self.jdbc_url}/{database_name}"
        dbtable_option = f"({query}) AS subquery" if query else table

        return self.spark_session.read \
            .format("jdbc") \
            .option("url", db_url) \
            .option("dbtable", dbtable_option) \
            .options(**self.properties) \
            .load()
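    # Usage sketch (assumed database/table/query names, placeholders only):
    #
    #     ops = PostgresOperations(spark, config)
    #     events = ops.read("analytics", table="events")
    #     counts = ops.read(
    #         "analytics",
    #         query="SELECT user_id, COUNT(*) AS n FROM events GROUP BY user_id"
    #     )
    #
    # When `query` is given, it is wrapped as "(<query>) AS subquery" and
    # passed as JDBC's dbtable option, so Spark treats it as a derived table.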
    def write(
        self,
        database_name: str,
        table: str,
        dataframe: DataFrame,
        mode: str = "overwrite"
    ):
        """
        Writes a Spark DataFrame to a specified table in a PostgreSQL
        database using JDBC.

        This method writes the provided DataFrame to the target PostgreSQL
        table, with the option to specify the write mode (overwrite,
        append, etc.).

        :param database_name: The name of the database to connect to.
        :param table: The name of the table where data will be written.
        :param dataframe: The Spark DataFrame containing the data to write.
        :param mode: The write mode. Valid options are 'overwrite',
            'append', 'ignore', and 'error'. Defaults to 'overwrite'.
            (optional)

        :note: Ensure that the schema of the DataFrame matches the schema
            of the target table if the table exists.
        :note: The `mode` parameter controls the behavior when the table
            already exists.
        """
        db_url = f"{self.jdbc_url}/{database_name}"

        dataframe.write.format("jdbc") \
            .option("url", db_url) \
            .option("dbtable", table) \
            .options(**self.properties) \
            .mode(mode) \
            .save()
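
# A self-contained, hedged usage demo, not part of the original module:
# the connection values, database, and table names below are placeholder
# assumptions, and running it requires a reachable PostgreSQL instance
# plus the JDBC driver on Spark's classpath.
if __name__ == "__main__":
    spark = (
        SparkSession.builder
        .appName("postgres-manager-demo")
        # Illustrative driver coordinate; use the version you actually need.
        .config("spark.jars.packages", "org.postgresql:postgresql:42.7.3")
        .getOrCreate()
    )

    config = PostgresConfig(
        PG_USER="spark_user",
        PG_PASS="change-me",
        PG_HOST="localhost",
        PG_PORT=5432,
    )
    ops = PostgresOperations(spark, config)

    # Write a small DataFrame, then read it back from the same table.
    df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])
    ops.write("demo_db", "demo_table", df, mode="overwrite")
    ops.read("demo_db", table="demo_table").show()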