// bogui/lib/riverpods/openai.dart
// 50 lines · 1.6 KiB · Dart
// 2023-01-07 13:02:53 +01:00
import 'package:flutter/material.dart';
import 'package:flutter_riverpod/flutter_riverpod.dart';
import 'package:openai_gpt3_api/openai_gpt3_api.dart';
/// GPT-3 client bundled with the Riverpod state used by the prompt UI.
///
/// The API key is read at compile time from the `OPENAPI_SECRET_KEY`
/// environment declaration (e.g. `--dart-define=OPENAPI_SECRET_KEY=...`).
class OpenAI extends GPT3 {
  // `String.fromEnvironment` is only guaranteed to work when invoked as
  // `const`; a non-const call can yield '' on ahead-of-time compiled
  // platforms, silently sending an empty API key.
  OpenAI() : super(const String.fromEnvironment('OPENAPI_SECRET_KEY'));

  /// Controller holding the prompt text; completions are appended to it.
  ///
  /// NOTE(review): these providers are instance fields, so every `OpenAI()`
  /// instance owns independent state — confirm the UI shares one instance.
  final prompt =
      StateProvider<TextEditingController>((ref) => TextEditingController());

  /// Whether a completion request is currently in flight.
  final isLoading = StateProvider<bool>((ref) => false);

  /// Sampling temperature passed to the completion request.
  final temperature = StateProvider<double>((ref) => 0.7);

  /// Requests a completion for the current prompt, appends it to the prompt
  /// text and moves the cursor to the end.
  ///
  /// Returns the generated text. Returns `''` without calling the API when a
  /// request is already running or the prompt is shorter than 2 characters.
  /// [isLoading] is reset even if the request throws, so a failed call does
  /// not block every later call; the error still propagates to the caller.
  Future<String> completionEasy(WidgetRef ref) async {
    if (ref.read(isLoading) || ref.read(prompt).text.length < 2) return '';
    ref.read(isLoading.notifier).state = true;
    try {
      final answer = await completion(ref.read(prompt).text,
          maxTokens: 250,
          engine: Engine.davinci3,
          temperature: ref.read(temperature),
          echo: false,
          stream: false);
      final answerText = answer.choices.map((choice) => choice.text).join();

      // Re-read after the await: the provider's controller is what the UI
      // is bound to at this point.
      final controller = ref.read(prompt);
      controller.text += answerText;
      controller.selection = TextSelection.fromPosition(
          TextPosition(offset: controller.text.length));
      return answerText;
    } finally {
      ref.read(isLoading.notifier).state = false;
    }
  }
}
// 2023-01-08 03:46:34 +01:00
/// Truncates [val] to at most [decimals] digits after the decimal point.
///
/// Truncation works on the shortest decimal representation of [val] (as
/// produced by [double.toString]), so the result never rounds up and avoids
/// the binary-precision artifacts of `(val * 10^n).truncate() / 10^n`.
/// Values with no more than [decimals] fractional digits, NaN and the
/// infinities are returned unchanged.
///
/// Throws a [RangeError] if [decimals] is negative.
double truncateDouble(double val, int decimals) {
  RangeError.checkNotNegative(decimals, 'decimals');
  if (!val.isFinite) return val;
  // toString() switches to exponent notation for tiny and huge magnitudes
  // ("1.5e-7", "1e+21"); expand it so only a sign, digits and '.' remain.
  final text = _toPlainDecimal(val.toString());
  final dotIndex = text.indexOf('.');
  // On the web, integral doubles print without a '.' (e.g. "3").
  if (dotIndex < 0) return val;
  final fractionDigits = text.length - dotIndex - 1;
  if (fractionDigits <= decimals) return val;
  // Drop the '.' entirely when no fractional digits are kept.
  final end = decimals == 0 ? dotIndex : dotIndex + decimals + 1;
  return double.parse(text.substring(0, end));
}

/// Rewrites a [double.toString] result that uses exponent notation (e.g.
/// "-1.5e-7" becomes "-0.00000015"; "1e+21" becomes "1" followed by 21
/// zeros and ".0") in plain positional notation. Strings without an
/// exponent are returned as-is.
String _toPlainDecimal(String text) {
  final eIndex = text.indexOf('e');
  if (eIndex < 0) return text;
  final exponent = int.parse(text.substring(eIndex + 1));
  var mantissa = text.substring(0, eIndex);
  final sign = mantissa.startsWith('-') ? '-' : '';
  if (sign.isNotEmpty) mantissa = mantissa.substring(1);
  final dot = mantissa.indexOf('.');
  final digits = dot < 0 ? mantissa : mantissa.replaceFirst('.', '');
  // Index in [digits] where the decimal point lands after the exponent.
  final point = (dot < 0 ? mantissa.length : dot) + exponent;
  if (point <= 0) return '${sign}0.${'0' * -point}$digits';
  if (point >= digits.length) {
    return '$sign$digits${'0' * (point - digits.length)}.0';
  }
  return '$sign${digits.substring(0, point)}.${digits.substring(point)}';
}