import { ConditionVariable } from "./condition-variable";

/** Default maximum number of buffered items before `send` starts waiting. */
const defaultChannelBackpressure = 512;

/**
 * In-memory channel: producers push values with `send` and consumers pull
 * them with `for await (const v of channel)`.
 *
 * - Bounded buffer: when `backpressure > 0`, `send` waits while the buffer
 *   already holds `backpressure` or more items. A `backpressure` of 0 (or a
 *   negative value) disables the bound.
 * - Closing: `close` marks the channel finished, wakes every waiter and runs
 *   the registered close handlers. Items already buffered are still delivered
 *   to consumers; items sent after closing are silently dropped.
 * - Errors: if the channel was closed with an error, iteration throws that
 *   error after the buffer has been drained.
 */
export class AsyncIterableChannel<T> {
    private readonly dataBuffer: T[] = [];
    private finished: boolean = false;
    private error: Error | undefined;
    // One condition variable shared by producers (waiting for free space) and
    // consumers (waiting for data or close). Every state change calls
    // notifyAll and each waiter re-checks its own loop condition.
    private readonly dataCV = new ConditionVariable();

    private readonly onCloseHandlers = new Set<(err: Error | undefined) => void>();

    /**
     * @param backpressure Maximum number of buffered items before `send` waits;
     *   `<= 0` means unbounded.
     */
    constructor(public readonly backpressure: number = defaultChannelBackpressure) {}

    /** True once `close` has been called (buffered items may still be pending). */
    public get isClosed(): boolean {
        return this.finished;
    }

    /**
     * Enqueues `data`, first waiting while the buffer is full.
     * If the channel is (or becomes) closed, the value is dropped and the
     * returned promise resolves without error.
     */
    public async send(data: T): Promise<void> {
        while (this.backpressure > 0 && this.dataBuffer.length >= this.backpressure && !this.finished) {
            await this.dataCV.wait();
        }
        if (this.finished) return;
        this.dataBuffer.push(data);
        this.dataCV.notifyAll();
    }

    /**
     * Registers `handler` to be called with the close error (or `undefined`)
     * when the channel closes. If the channel is already closed, `handler` is
     * invoked immediately so late subscribers are not silently ignored.
     * @returns This channel, for chaining.
     */
    public onClose(handler: (err: Error | undefined) => void): AsyncIterableChannel<T> {
        if (this.finished) {
            handler(this.error);
            return this;
        }
        this.onCloseHandlers.add(handler);
        return this;
    }

    /**
     * Closes the channel; only the first call has any effect.
     * @param err Optional error that consumers receive after draining the buffer.
     */
    public close(err?: Error): void {
        if (this.finished) return;

        this.finished = true;
        this.error = err;
        this.dataCV.notifyAll();
        for (const handler of this.onCloseHandlers) {
            handler(this.error);
        }
        // Handlers fire at most once (later registrations are invoked directly),
        // so drop the references to let their closures be collected.
        this.onCloseHandlers.clear();
    }

    /**
     * Yields buffered values in FIFO order until the channel is closed and
     * drained, then throws the close error if there was one.
     */
    public async *[Symbol.asyncIterator](): AsyncGenerator<T, void, unknown> {
        while (!this.finished || this.dataBuffer.length > 0) {
            while (!this.finished && this.dataBuffer.length === 0) {
                await this.dataCV.wait();
            }
            if (this.dataBuffer.length > 0) {
                // The length check guarantees shift() returns an element. Cast
                // instead of testing `!== undefined`, so channels whose T
                // includes `undefined` still deliver those values.
                const next = this.dataBuffer.shift() as T;
                // Wake producers blocked on backpressure now that there is room.
                this.dataCV.notifyAll();
                yield next;
            }
        }
        if (this.error !== undefined) {
            throw this.error;
        }
    }

    /** Returns a fresh async iterator over this channel. */
    public getIterator(): AsyncGenerator<T, void, unknown> {
        return this[Symbol.asyncIterator]();
    }
}

/**
 * Groups the values of `basis` into arrays and exposes them through a new
 * channel. A batch is emitted as soon as it holds at least `batchSize` items.
 * A trailing partial batch is emitted when `basis` finishes normally.
 *
 * Lifecycle:
 * - When `basis` is fully consumed, the returned channel is closed.
 * - If iterating `basis` throws (e.g. it was closed with an error), the
 *   returned channel is closed with that error and any partially filled
 *   batch is discarded.
 * - If the consumer closes the returned channel early, `basis` is closed as
 *   soon as the batching loop notices. That happens on the next item from
 *   `basis` or right after an in-flight send returns.
 *
 * @param basis Source channel, consumed by a background task started here.
 * @param batchSize Item count at which a batch is emitted.
 * @param backpressure Backpressure of the returned channel (default: 1 batch).
 * @returns A channel of item batches.
 */
export function batchedAsyncIterableChannel<T>(
    basis: AsyncIterableChannel<T>,
    batchSize: number,
    backpressure: number = 1
): AsyncIterableChannel<T[]> {
    const output = new AsyncIterableChannel<T[]>(backpressure);
    async function batch(): Promise<void> {
        try {
            let itemBatch: T[] = [];
            for await (const item of basis) {
                // Stop pulling from the source as soon as the consumer is gone,
                // rather than draining it until another full batch accumulates.
                if (output.isClosed) {
                    basis.close();
                    return;
                }
                itemBatch.push(item);
                if (itemBatch.length >= batchSize) {
                    await output.send(itemBatch);
                    itemBatch = [];
                    // send() may have been waiting on backpressure when the
                    // consumer closed the output; release the source right away.
                    if (output.isClosed) {
                        basis.close();
                        return;
                    }
                }
            }
            if (itemBatch.length > 0 && !output.isClosed) {
                await output.send(itemBatch);
            }
            output.close();
        } catch (e: unknown) {
            output.close(e instanceof Error ? e : new Error(String(e)));
        }
    }
    // Fire-and-forget: batch() handles its own errors by closing `output`.
    void batch();
    return output;
}
