Dela via


DataSourceArrowWriter

En basklass för datakällskrivare som bearbetar data med PyArrows RecordBatch.

Till skillnad från DataSourceWriter, som fungerar med en iterator av Spark-objekt Row , optimeras den här klassen för pilformatet när du skriver data. Det kan ge bättre prestanda när du interagerar med system eller bibliotek som har inbyggt stöd för Arrow. Implementera den här klassen och returnera en instans från DataSource.writer() för att göra en datakälla skrivbar med pilen.

Syntax

from pyspark.sql.datasource import DataSourceArrowWriter

class MyDataSourceArrowWriter(DataSourceArrowWriter):
    def write(self, iterator):
        ...

Methods

Metod Beskrivning
write(iterator) Skriver en iterator av PyArrow-objekt RecordBatch till mottagaren. Anropas en gång på varje köre. Returnerar ett WriterCommitMessage, eller None om det inte finns något incheckningsmeddelande. Den här metoden är abstrakt och måste implementeras.
commit(messages) Genomför skrivjobbet med hjälp av en lista över incheckningsmeddelanden som samlats in från alla utförare. Anropas på drivrutinen när alla aktiviteter körs. Ärvd från DataSourceWriter.
abort(messages) Avbryter skrivjobbet med hjälp av en lista över incheckningsmeddelanden som samlats in från alla utförare. Anropas på drivrutinen när en eller flera uppgifter misslyckades. Ärvd från DataSourceWriter.

Notes

  • Drivrutinen samlar in incheckningsmeddelanden från alla utförare och skickar dem till commit() om alla uppgifter lyckas, eller till abort() om någon uppgift misslyckas.
  • Om en skrivuppgift misslyckas finns dess incheckningsmeddelande i listan som skickas None till commit() eller abort().

Exempel

Implementera en pilbaserad skrivare som räknar rader i alla batchar:

from dataclasses import dataclass
from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, WriterCommitMessage

@dataclass
class MyCommitMessage(WriterCommitMessage):
    num_rows: int

class MyDataSourceArrowWriter(DataSourceArrowWriter):
    def write(self, iterator):
        total_rows = 0
        for batch in iterator:
            total_rows += len(batch)
        return MyCommitMessage(num_rows=total_rows)

    def commit(self, messages):
        total = sum(m.num_rows for m in messages if m is not None)
        print(f"Committed {total} rows")

    def abort(self, messages):
        print("Write job failed, performing cleanup")