diff options
author | David Timber <dxdt@dev.snart.me> | 2024-11-19 13:15:02 +0100 |
---|---|---|
committer | David Timber <dxdt@dev.snart.me> | 2024-11-19 13:15:02 +0100 |
commit | c01c6586bc1f79510688f35824c7049172063b58 (patch) | |
tree | 3627a47ffd83b846bd5bc76f626ff11c6816a56d /shitoutcode |
Diffstat (limited to 'shitoutcode')
-rw-r--r-- | shitoutcode/__init__.py | 86 | ||||
-rw-r--r-- | shitoutcode/__main__.py | 54 |
2 files changed, 140 insertions, 0 deletions
diff --git a/shitoutcode/__init__.py b/shitoutcode/__init__.py new file mode 100644 index 0000000..67da767 --- /dev/null +++ b/shitoutcode/__init__.py @@ -0,0 +1,86 @@ +import openai +import pyjson5 + + +class Env: + def __init__(self): + self.model = "gpt-4o" + self.output_tokens_min = 250 + self.output_tokens_max = 500 + self.temp: float = None + self.seed: int = None + self.max_tokens = 8000 + +class SourceFile: + def __init__(self): + self.name: str = None + self.category: str = None + self.contents: str = None + +class LLMAPIException (Exception): ... + +def __do_prompt (lang: str, extra_prompt: str, env: Env): + prompt = ''' +Write me a program source code written in %s. Write anything you'd like. +Use more than %d words, but no more than %d words in the code.''' % ( + lang, env.output_tokens_min, env.output_tokens_max) + if extra_prompt: + prompt = ' ' + extra_prompt + + messages = [ + { + "role": "system", + "content": "You're a helpful assistant that writes " + + "computer program of any kind" + }, + { "role": "user", "content": prompt } + ] + functions = [ + { + "name": "get_result", + "description": "Output the source code", + "parameters": { + "type": "object", + "properties": { + "cat": { + "type": "string", + "description": "one word describing program functions" + }, + "filename": { + "type": "string", + "description": "The source code file name" # TODO: strip unallowed characters in Unix fs + }, + "code": { + "type": "string", + "description": "The source code" + } + } + } + } + ] + + return openai.chat.completions.create( + model = env.model, + messages = messages, + functions = functions, + function_call = "auto", + temperature = env.temp, + seed = env.seed, + max_tokens = env.max_tokens + ) + +def gen_rand_srccode (lang: str, extra_prompt: str, env: Env) -> SourceFile | str: + rsp = __do_prompt(lang, extra_prompt, env) + match rsp.choices[0].finish_reason: + case 'stop': + return rsp.choices[0].message.content + case 'function_call': + choice = pyjson5.loads(rsp.choices[0].message.function_call.arguments) + + ret = SourceFile() + ret.name = choice["filename"] + ret.category = choice.get("cat") + ret.contents = choice["code"] + + return ret + case _: raise LLMAPIException(rsp.choices[0].finish_reason) diff --git a/shitoutcode/__main__.py b/shitoutcode/__main__.py new file mode 100644 index 0000000..09ca715 --- /dev/null +++ b/shitoutcode/__main__.py @@ -0,0 +1,54 @@ +import os +import random +import sys +from shitoutcode import Env, LLMAPIException, SourceFile, gen_rand_srccode + +ARGV0 = "shitoutcode" + +def loadEnv_from_env () -> Env: + def wrap_default (val: str, t: type, dv = None): + if val is None: + return dv + return t(val) + + ret = Env() + ret.model = wrap_default(os.getenv("LLM_MODEL"), str, "gpt-4o-mini") + ret.output_tokens_min = wrap_default(os.getenv("OUT_TOKENS_MIN"), int, ret.output_tokens_min) + ret.output_tokens_max = wrap_default(os.getenv("OUT_TOKENS_MAX"), int, ret.output_tokens_max) + ret.temp = wrap_default(os.getenv("LLM_TEMP"), float) + ret.seed = wrap_default(os.getenv("LLM_SEED"), int) + ret.max_tokens = wrap_default(os.getenv("LLM_SEED"), int, ret.max_tokens) + + return ret + +def getPathLine (x: SourceFile) -> str: + if x.category: + return x.category + os.path.sep + x.name + return x.name + +def pickRandLang () -> str: + thelist = [ + "C", + "C++", + "Javascript", + "Java" + ] + r = random.randint(0, len(thelist) - 1) + + return thelist[r] + +env = loadEnv_from_env() +lang = sys.argv[1] if len(sys.argv) > 1 else pickRandLang() +extra_prompt = sys.argv[2] if len(sys.argv) > 2 else None + +try: + result = gen_rand_srccode(lang, extra_prompt, env) + if result is SourceFile: + print(getPathLine(result)) + print() + print(result.contents) + else: + print(result) +except LLMAPIException as e: + sys.stderr.write(ARGV0 + ": model gave up: " + str(e) + os.linesep) + exit(1) |