Compartilhar via


DataSourceArrowWriter

Uma classe base para gravadores de fonte de dados que processam dados usando pyArrow's RecordBatch.

Ao contrário DataSourceWriterde , que funciona com um iterador de objetos Spark Row , essa classe é otimizada para o formato de Seta ao gravar dados. Ele pode oferecer melhor desempenho ao interagir com sistemas ou bibliotecas que dão suporte nativo à Seta. Implemente essa classe e retorne uma instância para tornar uma fonte de DataSource.writer() dados gravável usando Seta.

Sintaxe

from pyspark.sql.datasource import DataSourceArrowWriter

class MyDataSourceArrowWriter(DataSourceArrowWriter):
    def write(self, iterator):
        ...

Methods

Método Descrição
write(iterator) Grava um iterador de objetos PyArrow RecordBatch no coletor. Chamado uma vez em cada executor. Retorna um WriterCommitMessageou None se não houver nenhuma mensagem de confirmação. Esse método é abstrato e deve ser implementado.
commit(messages) Confirma o trabalho de gravação usando uma lista de mensagens de confirmação coletadas de todos os executores. Invocado no driver quando todas as tarefas são executadas com êxito. Herdado de DataSourceWriter.
abort(messages) Anula o trabalho de gravação usando uma lista de mensagens de confirmação coletadas de todos os executores. Invocado no driver quando uma ou mais tarefas falharam. Herdado de DataSourceWriter.

Observações

  • O driver coleta mensagens de confirmação de todos os executores e as passa para commit() se todas as tarefas tiverem êxito ou se abort() alguma tarefa falhar.
  • Se uma tarefa de gravação falhar, sua mensagem de confirmação estará None na lista passada para commit() ou abort().

Exemplos

Implemente um gravador baseado em seta que conta linhas em todos os lotes:

from dataclasses import dataclass
from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, WriterCommitMessage

@dataclass
class MyCommitMessage(WriterCommitMessage):
    num_rows: int

class MyDataSourceArrowWriter(DataSourceArrowWriter):
    def write(self, iterator):
        total_rows = 0
        for batch in iterator:
            total_rows += len(batch)
        return MyCommitMessage(num_rows=total_rows)

    def commit(self, messages):
        total = sum(m.num_rows for m in messages if m is not None)
        print(f"Committed {total} rows")

    def abort(self, messages):
        print("Write job failed, performing cleanup")